Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions go/internal/database/datastore/vulnerability.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"cloud.google.com/go/datastore"
"github.com/google/osv.dev/go/internal/models"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/api/iterator"
)

Expand Down Expand Up @@ -80,3 +81,15 @@ func (s *VulnerabilityStore) GetSourceModified(ctx context.Context, id string) (

return v.ModifiedRaw, nil
}

func (s *VulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) {
panic("not implemented")
}

func (s *VulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error {
panic("not implemented")
}

func (s *VulnerabilityStore) Withdraw(_ context.Context, _ string) error {
panic("not implemented")
}
13 changes: 13 additions & 0 deletions go/internal/importer/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/google/osv.dev/go/internal/models"
"github.com/ossf/osv-schema/bindings/go/osvschema"
)

type mockSourceRepositoryStore struct {
Expand Down Expand Up @@ -53,6 +54,18 @@ func (m *mockVulnerabilityStore) GetSourceModified(_ context.Context, vuln strin
return time.Time{}, models.ErrNotFound
}

func (m *mockVulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) {
panic("not implemented")
}

func (m *mockVulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error {
panic("not implemented")
}

func (m *mockVulnerabilityStore) Withdraw(_ context.Context, _ string) error {
panic("not implemented")
}

type mockSourceRecord struct {
DataToRead []byte
ReadError error
Expand Down
30 changes: 30 additions & 0 deletions go/internal/models/vulnerability.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"iter"
"time"

"github.com/ossf/osv-schema/bindings/go/osvschema"
)

// VulnSourceRef represents a minimal vulnerability entry for indexing/reconciliation.
Expand All @@ -15,10 +17,38 @@ type VulnSourceRef struct {
ModifiedRaw time.Time
}

// WriteRequest bundles everything needed to perform a permanent update to an OSV record.
type WriteRequest struct {
ID string
Source string // The source name (e.g. "debian")
Path string // The relative path in the source (e.g. "CVE-2023.json")
Raw *osvschema.Vulnerability // The original input proto
Enriched *osvschema.Vulnerability // The final enriched proto
AffectedCommits AffectedCommitsResult // Derived affected commits
}

// AffectedCommitsResult handles the distinction between 'no change' and 'set to empty'.
type AffectedCommitsResult struct {
Commits [][]byte
Skip bool // If true, the store should not modify existing affected commits for this ID.
}

type VulnerabilityStore interface {
// ListBySource returns an iterator over vulnerabilities for a given source.
ListBySource(ctx context.Context, source string, skipWithdrawn bool) iter.Seq2[*VulnSourceRef, error]

// GetSourceModified returns the modified time of a vulnerability according to the source.
// Returns ErrNotFound if the vulnerability is not found.
GetSourceModified(ctx context.Context, id string) (time.Time, error)

// Get returns the fully enriched vulnerability.
Get(ctx context.Context, id string) (*osvschema.Vulnerability, error)

// Write atomically updates the base record and its derived indexes.
Write(ctx context.Context, req WriteRequest) error

// Withdraw marks a vulnerability as withdrawn/deleted.
// Sets the withdrawn + modified dates to the current time.
// If the vulnerability is already withdrawn, this is a no-op.
Withdraw(ctx context.Context, id string) error
}
103 changes: 103 additions & 0 deletions go/internal/worker/engine.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Package worker contains the implementation for the vulnerability enrichment worker pipeline.
package worker

import (
"context"
"errors"
"fmt"
"log/slog"

"github.com/google/go-cmp/cmp"
"github.com/google/osv.dev/go/internal/models"
"github.com/google/osv.dev/go/internal/worker/pipeline"
"github.com/google/osv.dev/go/logger"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
)

type Engine struct {
Stores Stores
Pipeline []pipeline.Enricher
}

func (e *Engine) RunTask(ctx context.Context, task Task) error {
switch task.Type {
case TaskDelete:
return e.handleDelete(ctx, task)
case TaskUpdate:
return e.handleUpdate(ctx, task)
default:
return fmt.Errorf("unknown task type: %v", task.Type)
}
}

func (e *Engine) handleUpdate(ctx context.Context, task Task) error {
params := pipeline.EnrichParams{
PathInSource: task.PathInSource,
}
var err error
params.SourceRepo, err = e.Stores.SourceRepo.Get(ctx, task.SourceID)
if err != nil {
return err
}
if task.Vuln == nil {
// TODO: Download Vuln from source
return errors.New("vuln not provided")
}

enriched := proto.Clone(task.Vuln).(*osvschema.Vulnerability)
for _, enricher := range e.Pipeline {
if err := enricher.Enrich(ctx, enriched, &params); err != nil {
logger.ErrorContext(ctx, "Enricher failed with error",
slog.String("id", task.Vuln.GetId()),
slog.String("enricher", fmt.Sprintf("%T", enricher)),
slog.Any("error", err),
)

return err
}
}

// TODO: affected commits

// Get the current state of the vuln to check against
current, err := e.Stores.Vulnerability.Get(ctx, enriched.GetId())
isNotFound := errors.Is(err, models.ErrNotFound)

if err != nil && !isNotFound {
logger.ErrorContext(ctx, "Failed to get current vuln state", slog.String("vuln_id", enriched.GetId()), slog.Any("error", err))

return fmt.Errorf("failed to get current vuln state: %w", err)
}

if isNotFound || e.isSemanticallyDifferent(current, enriched) {
enriched.Modified = timestamppb.Now()
} else if current.GetModified().AsTime().After(enriched.GetModified().AsTime()) {
enriched.Modified = current.GetModified()
}

return e.Stores.Vulnerability.Write(ctx, models.WriteRequest{
ID: enriched.GetId(),
Source: task.SourceID,
Path: task.PathInSource,
Raw: task.Vuln,
Enriched: enriched,
AffectedCommits: models.AffectedCommitsResult{
Skip: true,
},
})
}

func (e *Engine) isSemanticallyDifferent(v1, v2 *osvschema.Vulnerability) bool {
return !cmp.Equal(v1, v2,
protocmp.Transform(),
protocmp.IgnoreFields(&osvschema.Vulnerability{}, "modified", "published"),
)
}

func (e *Engine) handleDelete(_ context.Context, _ Task) error {
// TODO
return nil
}
18 changes: 18 additions & 0 deletions go/internal/worker/pipeline/enrich.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Package pipeline contains individual vulnerability enrichers for the worker pipeline.
package pipeline

import (
"context"

"github.com/google/osv.dev/go/internal/models"
"github.com/ossf/osv-schema/bindings/go/osvschema"
)

type EnrichParams struct {
PathInSource string
SourceRepo *models.SourceRepository
}

type Enricher interface {
Enrich(ctx context.Context, vuln *osvschema.Vulnerability, params *EnrichParams) error
}
12 changes: 12 additions & 0 deletions go/internal/worker/pipeline/registry/registry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Package registry contains all the enrichers that are used in the worker pipeline.
package registry

import (
"github.com/google/osv.dev/go/internal/worker/pipeline"
"github.com/google/osv.dev/go/internal/worker/pipeline/sourcelink"
)

// List is the list of all enrichers used in the worker pipeline.
var List = []pipeline.Enricher{
&sourcelink.Enricher{},
}
33 changes: 33 additions & 0 deletions go/internal/worker/pipeline/sourcelink/sourcelink.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Package sourcelink implements an enricher that adds the source link to the vulnerability.
// The source link is added under the database_specific field under each affected range,
// with they key "source" and the value being the full path to the vulnerability in the source repo.
package sourcelink

import (
"context"

"github.com/google/osv.dev/go/internal/worker/pipeline"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/protobuf/types/known/structpb"
)

type Enricher struct{}

var _ pipeline.Enricher = (*Enricher)(nil)

func (*Enricher) Enrich(_ context.Context, vuln *osvschema.Vulnerability, params *pipeline.EnrichParams) error {
if params.SourceRepo == nil || params.SourceRepo.Link == "" {
return nil
}
sourceLink := structpb.NewStringValue(params.SourceRepo.Link + params.PathInSource)

for _, affected := range vuln.GetAffected() {
if affected.GetDatabaseSpecific() == nil {
// The error would only be from an invalid map value, passing nil is fine.
affected.DatabaseSpecific, _ = structpb.NewStruct(nil)
}
affected.DatabaseSpecific.Fields["source"] = sourceLink
}

return nil
}
124 changes: 124 additions & 0 deletions go/internal/worker/subscriber.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package worker

import (
"context"
"fmt"
"log/slog"
"strconv"
"time"

"cloud.google.com/go/pubsub/v2"
"github.com/google/osv.dev/go/logger"
"github.com/klauspost/compress/zstd"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/protobuf/proto"
)

type Subscriber struct {
Engine Engine
PubSubSub *pubsub.Subscriber
}

func (s *Subscriber) Run(ctx context.Context) error {
return s.PubSubSub.Receive(ctx, s.handleMessage)
}

func (s *Subscriber) handleMessage(ctx context.Context, m *pubsub.Message) {
if taskType := m.Attributes["type"]; taskType != "update" {
logger.InfoContext(ctx, "Skipping message, not an update", slog.Any("task_type", taskType))
m.Ack()

return
}

taskCtx := otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(m.Attributes))
taskCtx, span := otel.Tracer("worker").Start(taskCtx, "process_message")
defer span.End()
task := Task{
SourceID: m.Attributes["source"],
PathInSource: m.Attributes["path"],
}

logInfo := []any{
slog.String("source", task.SourceID),
slog.String("path", task.PathInSource),
}

var err error
task.Vuln, err = s.parseVuln(m)
if err != nil {
logger.ErrorContext(taskCtx, "Failed to parse vulnerability", append(logInfo, slog.Any("error", err))...)
m.Nack()

return
}

deleted, err := strconv.ParseBool(m.Attributes["deleted"])
if err != nil {
logger.ErrorContext(taskCtx, "Failed to parse deleted attribute, defaulting to false", append(logInfo, slog.Any("error", err))...)
deleted = false
}
if deleted {
task.Type = TaskDelete
} else {
task.Type = TaskUpdate
}

task.ReceivedTime, err = s.timeFromUnixSeconds(m.Attributes["req_timestamp"])
if err != nil {
logger.ErrorContext(taskCtx, "Failed to parse req_timestamp attribute, ignoring", append(logInfo, slog.Any("error", err))...)
}
srcTime := m.Attributes["src_timestamp"]
if srcTime != "" {
task.SourceTime, err = s.timeFromUnixSeconds(srcTime)
if err != nil {
logger.ErrorContext(taskCtx, "Failed to parse src_timestamp attribute, ignoring", append(logInfo, slog.Any("error", err))...)
}
}

skipHash, ok := m.Attributes["skip_hash_check"]
if !ok || skipHash != "true" {
task.SHA256 = m.Attributes["original_sha256"]
}

if err := s.Engine.RunTask(taskCtx, task); err != nil {
logger.ErrorContext(taskCtx, "Failed to process task", append(logInfo, slog.Any("error", err))...)
m.Nack()
} else {
m.Ack()
}
}

func (s *Subscriber) parseVuln(m *pubsub.Message) (*osvschema.Vulnerability, error) {
if len(m.Data) == 0 {
//nolint:nilnil // this is expected for delete requests
return nil, nil
}
if m.Attributes["content_encoding"] != "zstd" {
return nil, fmt.Errorf("unrecognized content encoding: %s", m.Attributes["content_encoding"])
}
// TODO: try to extract the actual uncompressed size from the zstd frame.
buf := make([]byte, 0, len(m.Data)*3)
buf, err := zstd.DecodeTo(buf, m.Data)
if err != nil {
return nil, fmt.Errorf("failed to decompress vulnerability: %w", err)
}
v := &osvschema.Vulnerability{}
if err := proto.Unmarshal(buf, v); err != nil {
return nil, fmt.Errorf("failed to unmarshal vulnerability: %w", err)
}

return v, nil
}

func (s *Subscriber) timeFromUnixSeconds(tsString string) (*time.Time, error) {
timestamp, err := strconv.ParseInt(tsString, 10, 64)
if err != nil {
return nil, err
}
ts := time.Unix(timestamp, 0)

return &ts, nil
}
Loading
Loading