diff --git a/go/internal/database/datastore/vulnerability.go b/go/internal/database/datastore/vulnerability.go index a4abfebde5c..162772fdebd 100644 --- a/go/internal/database/datastore/vulnerability.go +++ b/go/internal/database/datastore/vulnerability.go @@ -10,6 +10,7 @@ import ( "cloud.google.com/go/datastore" "github.com/google/osv.dev/go/internal/models" + "github.com/ossf/osv-schema/bindings/go/osvschema" "google.golang.org/api/iterator" ) @@ -80,3 +81,15 @@ func (s *VulnerabilityStore) GetSourceModified(ctx context.Context, id string) ( return v.ModifiedRaw, nil } + +func (s *VulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) { + panic("not implemented") +} + +func (s *VulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error { + panic("not implemented") +} + +func (s *VulnerabilityStore) Withdraw(_ context.Context, _ string) error { + panic("not implemented") +} diff --git a/go/internal/importer/mock_test.go b/go/internal/importer/mock_test.go index dc7f73e106a..fea2d38ec11 100644 --- a/go/internal/importer/mock_test.go +++ b/go/internal/importer/mock_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/google/osv.dev/go/internal/models" + "github.com/ossf/osv-schema/bindings/go/osvschema" ) type mockSourceRepositoryStore struct { @@ -53,6 +54,18 @@ func (m *mockVulnerabilityStore) GetSourceModified(_ context.Context, vuln strin return time.Time{}, models.ErrNotFound } +func (m *mockVulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) { + panic("not implemented") +} + +func (m *mockVulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error { + panic("not implemented") +} + +func (m *mockVulnerabilityStore) Withdraw(_ context.Context, _ string) error { + panic("not implemented") +} + type mockSourceRecord struct { DataToRead []byte ReadError error diff --git a/go/internal/models/vulnerability.go b/go/internal/models/vulnerability.go index f761409a9b4..898c2423365 100644 --- a/go/internal/models/vulnerability.go +++ b/go/internal/models/vulnerability.go @@ -5,6 +5,8 @@ import ( "context" "iter" "time" + + "github.com/ossf/osv-schema/bindings/go/osvschema" ) // VulnSourceRef represents a minimal vulnerability entry for indexing/reconciliation. @@ -15,10 +17,38 @@ type VulnSourceRef struct { ModifiedRaw time.Time } +// WriteRequest bundles everything needed to perform a permanent update to an OSV record. +type WriteRequest struct { + ID string + Source string // The source name (e.g. "debian") + Path string // The relative path in the source (e.g. "CVE-2023.json") + Raw *osvschema.Vulnerability // The original input proto + Enriched *osvschema.Vulnerability // The final enriched proto + AffectedCommits AffectedCommitsResult // Derived affected commits +} + +// AffectedCommitsResult handles the distinction between 'no change' and 'set to empty'. +type AffectedCommitsResult struct { + Commits [][]byte + Skip bool // If true, the store should not modify existing affected commits for this ID. +} + type VulnerabilityStore interface { // ListBySource returns an iterator over vulnerabilities for a given source. ListBySource(ctx context.Context, source string, skipWithdrawn bool) iter.Seq2[*VulnSourceRef, error] + // GetSourceModified returns the modified time of a vulnerability according to the source. // Returns ErrNotFound if the vulnerability is not found. GetSourceModified(ctx context.Context, id string) (time.Time, error) + + // Get returns the fully enriched vulnerability. + Get(ctx context.Context, id string) (*osvschema.Vulnerability, error) + + // Write atomically updates the base record and its derived indexes. + Write(ctx context.Context, req WriteRequest) error + + // Withdraw marks a vulnerability as withdrawn/deleted. + // Sets the withdrawn + modified dates to the current time. + // If the vulnerability is already withdrawn, this is a no-op. + Withdraw(ctx context.Context, id string) error } diff --git a/go/internal/worker/engine.go b/go/internal/worker/engine.go new file mode 100644 index 00000000000..1f518731a0a --- /dev/null +++ b/go/internal/worker/engine.go @@ -0,0 +1,103 @@ +// Package worker contains the implementation for the vulnerability enrichment worker pipeline. +package worker + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/google/go-cmp/cmp" + "github.com/google/osv.dev/go/internal/models" + "github.com/google/osv.dev/go/internal/worker/pipeline" + "github.com/google/osv.dev/go/logger" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type Engine struct { + Stores Stores + Pipeline []pipeline.Enricher +} + +func (e *Engine) RunTask(ctx context.Context, task Task) error { + switch task.Type { + case TaskDelete: + return e.handleDelete(ctx, task) + case TaskUpdate: + return e.handleUpdate(ctx, task) + default: + return fmt.Errorf("unknown task type: %v", task.Type) + } +} + +func (e *Engine) handleUpdate(ctx context.Context, task Task) error { + params := pipeline.EnrichParams{ + PathInSource: task.PathInSource, + } + var err error + params.SourceRepo, err = e.Stores.SourceRepo.Get(ctx, task.SourceID) + if err != nil { + return err + } + if task.Vuln == nil { + // TODO: Download Vuln from source + return errors.New("vuln not provided") + } + + enriched := proto.Clone(task.Vuln).(*osvschema.Vulnerability) + for _, enricher := range e.Pipeline { + if err := enricher.Enrich(ctx, enriched, ¶ms); err != nil { + logger.ErrorContext(ctx, "Enricher failed with error", + slog.String("id", task.Vuln.GetId()), + slog.String("enricher", fmt.Sprintf("%T", enricher)), + slog.Any("error", err), + ) + + return err + } + } + + // TODO: affected commits + + // Get the current state of the vuln to check against + current, err := e.Stores.Vulnerability.Get(ctx, enriched.GetId()) + isNotFound := errors.Is(err, models.ErrNotFound) + + if err != nil && !isNotFound { + logger.ErrorContext(ctx, "Failed to get current vuln state", slog.String("vuln_id", enriched.GetId()), slog.Any("error", err)) + + return fmt.Errorf("failed to get current vuln state: %w", err) + } + + if isNotFound || e.isSemanticallyDifferent(current, enriched) { + enriched.Modified = timestamppb.Now() + } else if current.GetModified().AsTime().After(enriched.GetModified().AsTime()) { + enriched.Modified = current.GetModified() + } + + return e.Stores.Vulnerability.Write(ctx, models.WriteRequest{ + ID: enriched.GetId(), + Source: task.SourceID, + Path: task.PathInSource, + Raw: task.Vuln, + Enriched: enriched, + AffectedCommits: models.AffectedCommitsResult{ + Skip: true, + }, + }) +} + +func (e *Engine) isSemanticallyDifferent(v1, v2 *osvschema.Vulnerability) bool { + return !cmp.Equal(v1, v2, + protocmp.Transform(), + protocmp.IgnoreFields(&osvschema.Vulnerability{}, "modified", "published"), + ) +} + +func (e *Engine) handleDelete(_ context.Context, _ Task) error { + // TODO + return nil +} diff --git a/go/internal/worker/pipeline/enrich.go b/go/internal/worker/pipeline/enrich.go new file mode 100644 index 00000000000..d189bbfb174 --- /dev/null +++ b/go/internal/worker/pipeline/enrich.go @@ -0,0 +1,18 @@ +// Package pipeline contains individual vulnerability enrichers for the worker pipeline. +package pipeline + +import ( + "context" + + "github.com/google/osv.dev/go/internal/models" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +type EnrichParams struct { + PathInSource string + SourceRepo *models.SourceRepository +} + +type Enricher interface { + Enrich(ctx context.Context, vuln *osvschema.Vulnerability, params *EnrichParams) error +} diff --git a/go/internal/worker/pipeline/registry/registry.go b/go/internal/worker/pipeline/registry/registry.go new file mode 100644 index 00000000000..e71043840aa --- /dev/null +++ b/go/internal/worker/pipeline/registry/registry.go @@ -0,0 +1,12 @@ +// Package registry contains all the enrichers that are used in the worker pipeline. +package registry + +import ( + "github.com/google/osv.dev/go/internal/worker/pipeline" + "github.com/google/osv.dev/go/internal/worker/pipeline/sourcelink" +) + +// List is the list of all enrichers used in the worker pipeline. +var List = []pipeline.Enricher{ + &sourcelink.Enricher{}, +} diff --git a/go/internal/worker/pipeline/sourcelink/sourcelink.go b/go/internal/worker/pipeline/sourcelink/sourcelink.go new file mode 100644 index 00000000000..829c2b2eb02 --- /dev/null +++ b/go/internal/worker/pipeline/sourcelink/sourcelink.go @@ -0,0 +1,33 @@ +// Package sourcelink implements an enricher that adds the source link to the vulnerability. +// The source link is added under the database_specific field under each affected range, +// with they key "source" and the value being the full path to the vulnerability in the source repo. +package sourcelink + +import ( + "context" + + "github.com/google/osv.dev/go/internal/worker/pipeline" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/protobuf/types/known/structpb" +) + +type Enricher struct{} + +var _ pipeline.Enricher = (*Enricher)(nil) + +func (*Enricher) Enrich(_ context.Context, vuln *osvschema.Vulnerability, params *pipeline.EnrichParams) error { + if params.SourceRepo == nil || params.SourceRepo.Link == "" { + return nil + } + sourceLink := structpb.NewStringValue(params.SourceRepo.Link + params.PathInSource) + + for _, affected := range vuln.GetAffected() { + if affected.GetDatabaseSpecific() == nil { + // The error would only be from an invalid map value, passing nil is fine. + affected.DatabaseSpecific, _ = structpb.NewStruct(nil) + } + affected.DatabaseSpecific.Fields["source"] = sourceLink + } + + return nil +} diff --git a/go/internal/worker/subscriber.go b/go/internal/worker/subscriber.go new file mode 100644 index 00000000000..e4da5ea2016 --- /dev/null +++ b/go/internal/worker/subscriber.go @@ -0,0 +1,124 @@ +package worker + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "time" + + "cloud.google.com/go/pubsub/v2" + "github.com/google/osv.dev/go/logger" + "github.com/klauspost/compress/zstd" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + "google.golang.org/protobuf/proto" +) + +type Subscriber struct { + Engine Engine + PubSubSub *pubsub.Subscriber +} + +func (s *Subscriber) Run(ctx context.Context) error { + return s.PubSubSub.Receive(ctx, s.handleMessage) +} + +func (s *Subscriber) handleMessage(ctx context.Context, m *pubsub.Message) { + if taskType := m.Attributes["type"]; taskType != "update" { + logger.InfoContext(ctx, "Skipping message, not an update", slog.Any("task_type", taskType)) + m.Ack() + + return + } + + taskCtx := otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(m.Attributes)) + taskCtx, span := otel.Tracer("worker").Start(taskCtx, "process_message") + defer span.End() + task := Task{ + SourceID: m.Attributes["source"], + PathInSource: m.Attributes["path"], + } + + logInfo := []any{ + slog.String("source", task.SourceID), + slog.String("path", task.PathInSource), + } + + var err error + task.Vuln, err = s.parseVuln(m) + if err != nil { + logger.ErrorContext(taskCtx, "Failed to parse vulnerability", append(logInfo, slog.Any("error", err))...) + m.Nack() + + return + } + + deleted, err := strconv.ParseBool(m.Attributes["deleted"]) + if err != nil { + logger.ErrorContext(taskCtx, "Failed to parse deleted attribute, defaulting to false", append(logInfo, slog.Any("error", err))...) + deleted = false + } + if deleted { + task.Type = TaskDelete + } else { + task.Type = TaskUpdate + } + + task.ReceivedTime, err = s.timeFromUnixSeconds(m.Attributes["req_timestamp"]) + if err != nil { + logger.ErrorContext(taskCtx, "Failed to parse req_timestamp attribute, ignoring", append(logInfo, slog.Any("error", err))...) + } + srcTime := m.Attributes["src_timestamp"] + if srcTime != "" { + task.SourceTime, err = s.timeFromUnixSeconds(srcTime) + if err != nil { + logger.ErrorContext(taskCtx, "Failed to parse src_timestamp attribute, ignoring", append(logInfo, slog.Any("error", err))...) + } + } + + skipHash, ok := m.Attributes["skip_hash_check"] + if !ok || skipHash != "true" { + task.SHA256 = m.Attributes["original_sha256"] + } + + if err := s.Engine.RunTask(taskCtx, task); err != nil { + logger.ErrorContext(taskCtx, "Failed to process task", append(logInfo, slog.Any("error", err))...) + m.Nack() + } else { + m.Ack() + } +} + +func (s *Subscriber) parseVuln(m *pubsub.Message) (*osvschema.Vulnerability, error) { + if len(m.Data) == 0 { + //nolint:nilnil // this is expected for delete requests + return nil, nil + } + if m.Attributes["content_encoding"] != "zstd" { + return nil, fmt.Errorf("unrecognized content encoding: %s", m.Attributes["content_encoding"]) + } + // TODO: try to extract the actual uncompressed size from the zstd frame. + buf := make([]byte, 0, len(m.Data)*3) + buf, err := zstd.DecodeTo(buf, m.Data) + if err != nil { + return nil, fmt.Errorf("failed to decompress vulnerability: %w", err) + } + v := &osvschema.Vulnerability{} + if err := proto.Unmarshal(buf, v); err != nil { + return nil, fmt.Errorf("failed to unmarshal vulnerability: %w", err) + } + + return v, nil +} + +func (s *Subscriber) timeFromUnixSeconds(tsString string) (*time.Time, error) { + timestamp, err := strconv.ParseInt(tsString, 10, 64) + if err != nil { + return nil, err + } + ts := time.Unix(timestamp, 0) + + return &ts, nil +} diff --git a/go/internal/worker/worker.go b/go/internal/worker/worker.go new file mode 100644 index 00000000000..9d043abfe9a --- /dev/null +++ b/go/internal/worker/worker.go @@ -0,0 +1,35 @@ +// Package worker contains the implementation for the vulnerability enrichment worker pipeline. +package worker + +import ( + "time" + + "github.com/google/osv.dev/go/internal/models" + "github.com/ossf/osv-schema/bindings/go/osvschema" +) + +type TaskType int + +const ( + TaskUnknown TaskType = iota + TaskUpdate + TaskDelete +) + +type Task struct { + Type TaskType + Vuln *osvschema.Vulnerability + SourceID string + PathInSource string + // ReceivedTime is when the importer requested the vuln to be processed. + ReceivedTime *time.Time + // SourceTime is the modified time according to the source + SourceTime *time.Time + // SHA256 is only used when Vuln is not provided + SHA256 string +} + +type Stores struct { + SourceRepo models.SourceRepositoryStore + Vulnerability models.VulnerabilityStore +}