diff --git a/services/tool-learning/.gitignore b/services/tool-learning/.gitignore new file mode 100644 index 0000000..ff6ebc3 --- /dev/null +++ b/services/tool-learning/.gitignore @@ -0,0 +1,3 @@ +demo +seed-lake +tool-learning diff --git a/services/tool-learning/cmd/demo/main.go b/services/tool-learning/cmd/demo/main.go index cce13a8..d581319 100644 --- a/services/tool-learning/cmd/demo/main.go +++ b/services/tool-learning/cmd/demo/main.go @@ -322,7 +322,7 @@ func startServices(db *sql.DB) ( } store, err := valkey.NewPolicyStoreFromAddress( - context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 2*time.Hour, + context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 2*time.Hour, nil, ) if err != nil { redisSrv.Close() diff --git a/services/tool-learning/cmd/tool-learning/adapters_test.go b/services/tool-learning/cmd/tool-learning/adapters_test.go index c655cdb..35fc4ab 100644 --- a/services/tool-learning/cmd/tool-learning/adapters_test.go +++ b/services/tool-learning/cmd/tool-learning/adapters_test.go @@ -70,7 +70,7 @@ func buildTestAdapters(t *testing.T) ( // Valkey via miniredis redisSrv := startMiniredis(t) - store, err := valkey.NewPolicyStoreFromAddress(context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute) + store, err := valkey.NewPolicyStoreFromAddress(context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil) if err != nil { t.Fatalf("valkey: %v", err) } @@ -84,7 +84,7 @@ func buildTestAdapters(t *testing.T) ( pub := natspub.NewPublisher(conn, "hourly") // S3 audit store (client creation doesn't connect) - audit, err := s3store.NewAuditStoreFromConfig("localhost:9000", "test", "test", "test-audit", false) + audit, err := s3store.NewAuditStoreFromConfig("localhost:9000", "test", "test", "test-audit", false, nil) if err != nil { t.Fatalf("audit store: %v", err) } diff --git a/services/tool-learning/cmd/tool-learning/main.go b/services/tool-learning/cmd/tool-learning/main.go index 1832f35..d8710ee 100644 --- a/services/tool-learning/cmd/tool-learning/main.go +++ b/services/tool-learning/cmd/tool-learning/main.go @@ -2,11 +2,14 @@ package main import ( "context" + "crypto/tls" + "crypto/x509" "flag" "fmt" "log/slog" "os" "strconv" + "strings" "time" "github.com/underpass-ai/underpass-runtime/services/tool-learning/internal/adapters/duckdb" @@ -135,6 +138,7 @@ type adapterConfig struct { S3SecretKey string S3Region string S3UseSSL bool + S3TLS *tls.Config LakeBucket string AuditBucket string ValkeyAddr string @@ -142,7 +146,9 @@ type adapterConfig struct { ValkeyDB int ValkeyPfx string ValkeyTTL time.Duration + ValkeyTLS *tls.Config NATSURL string + NATSTLS *tls.Config Schedule string } @@ -156,12 +162,42 @@ func loadConfig(schedule string) (adapterConfig, error) { if err != nil { return adapterConfig{}, fmt.Errorf("invalid VALKEY_TTL: %w", err) } + valkeyTLS, err := buildClientTLS( + os.Getenv("VALKEY_TLS_ENABLED") == "true", + os.Getenv("VALKEY_TLS_CA_PATH"), + os.Getenv("VALKEY_TLS_CERT_PATH"), + os.Getenv("VALKEY_TLS_KEY_PATH"), + ) + if err != nil { + return adapterConfig{}, fmt.Errorf("valkey TLS: %w", err) + } + + natsTLS, err := buildClientTLS( + strings.TrimSpace(os.Getenv("NATS_TLS_MODE")) != "" && strings.TrimSpace(os.Getenv("NATS_TLS_MODE")) != "disabled", + os.Getenv("NATS_TLS_CA_PATH"), + os.Getenv("NATS_TLS_CERT_PATH"), + os.Getenv("NATS_TLS_KEY_PATH"), + ) + if err != nil { + return adapterConfig{}, fmt.Errorf("nats TLS: %w", err) + } + + s3TLS, err := buildClientTLS( + envOrDefault("S3_USE_SSL", "false") == "true", + os.Getenv("S3_CA_PATH"), + "", "", + ) + if err != nil { + return adapterConfig{}, fmt.Errorf("s3 TLS: %w", err) + } + return adapterConfig{ S3Endpoint: envOrDefault("S3_ENDPOINT", "localhost:9000"), S3AccessKey: envOrDefault("S3_ACCESS_KEY", ""), S3SecretKey: envOrDefault("S3_SECRET_KEY", ""), S3Region: envOrDefault("S3_REGION", "us-east-1"), S3UseSSL: envOrDefault("S3_USE_SSL", "false") == "true", + S3TLS: s3TLS, LakeBucket: envOrDefault("LAKE_BUCKET", "telemetry-lake"), AuditBucket: envOrDefault("AUDIT_BUCKET", "policy-audit"), ValkeyAddr: envOrDefault("VALKEY_ADDR", "localhost:6379"), @@ -169,7 +205,9 @@ func loadConfig(schedule string) (adapterConfig, error) { ValkeyDB: valkeyDB, ValkeyPfx: envOrDefault("VALKEY_KEY_PREFIX", "tool_policy"), ValkeyTTL: valkeyTTL, + ValkeyTLS: valkeyTLS, NATSURL: envOrDefault("NATS_URL", "nats://localhost:4222"), + NATSTLS: natsTLS, Schedule: schedule, }, nil } @@ -190,21 +228,21 @@ func buildAdapters(cfg adapterConfig, logger *slog.Logger) ( } logger.Info(logAdapterReady, "adapter", "duckdb-lake-reader", "bucket", cfg.LakeBucket) - store, err := valkey.NewPolicyStoreFromAddress(context.Background(), cfg.ValkeyAddr, cfg.ValkeyPass, cfg.ValkeyDB, cfg.ValkeyPfx, cfg.ValkeyTTL) + store, err := valkey.NewPolicyStoreFromAddress(context.Background(), cfg.ValkeyAddr, cfg.ValkeyPass, cfg.ValkeyDB, cfg.ValkeyPfx, cfg.ValkeyTTL, cfg.ValkeyTLS) if err != nil { _ = lake.Close() return nil, nil, nil, nil, nil, fmt.Errorf("valkey policy store: %w", err) } logger.Info(logAdapterReady, "adapter", "valkey-policy-store", "addr", cfg.ValkeyAddr) - pub, natsConn, err := natspub.NewPublisherFromURL(cfg.NATSURL, cfg.Schedule) + pub, natsConn, err := natspub.NewPublisherFromURL(cfg.NATSURL, cfg.Schedule, cfg.NATSTLS) if err != nil { _ = lake.Close() return nil, nil, nil, nil, nil, fmt.Errorf("nats publisher: %w", err) } logger.Info(logAdapterReady, "adapter", "nats-publisher", "url", cfg.NATSURL) - audit, err := s3.NewAuditStoreFromConfig(cfg.S3Endpoint, cfg.S3AccessKey, cfg.S3SecretKey, cfg.AuditBucket, cfg.S3UseSSL) + audit, err := s3.NewAuditStoreFromConfig(cfg.S3Endpoint, cfg.S3AccessKey, cfg.S3SecretKey, cfg.AuditBucket, cfg.S3UseSSL, cfg.S3TLS) if err != nil { _ = lake.Close() natsConn.Close() @@ -221,6 +259,34 @@ func buildAdapters(cfg adapterConfig, logger *slog.Logger) ( return lake, store, pub, audit, cleanup, nil } +// buildClientTLS builds a *tls.Config for outgoing connections. +// Returns nil when enabled is false (TLS disabled). +func buildClientTLS(enabled bool, caPath, certPath, keyPath string) (*tls.Config, error) { + if !enabled { + return nil, nil + } + cfg := &tls.Config{MinVersion: tls.VersionTLS13} + if caPath != "" { + data, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("read CA %s: %w", caPath, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("no valid certs in %s", caPath) + } + cfg.RootCAs = pool + } + if certPath != "" && keyPath != "" { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("load cert/key: %w", err) + } + cfg.Certificates = []tls.Certificate{cert} + } + return cfg, nil +} + func envOrDefault(key, fallback string) string { if v := os.Getenv(key); v != "" { return v diff --git a/services/tool-learning/internal/adapters/nats/publisher.go b/services/tool-learning/internal/adapters/nats/publisher.go index ec026da..2744054 100644 --- a/services/tool-learning/internal/adapters/nats/publisher.go +++ b/services/tool-learning/internal/adapters/nats/publisher.go @@ -2,6 +2,7 @@ package nats import ( "context" + "crypto/tls" "encoding/json" "fmt" "time" @@ -35,8 +36,12 @@ func NewPublisher(conn *nats.Conn, schedule string) *Publisher { } // NewPublisherFromURL connects to NATS and returns a publisher. -func NewPublisherFromURL(url, schedule string) (*Publisher, *nats.Conn, error) { - conn, err := nats.Connect(url) +func NewPublisherFromURL(url, schedule string, tlsCfg *tls.Config) (*Publisher, *nats.Conn, error) { + var opts []nats.Option + if tlsCfg != nil { + opts = append(opts, nats.Secure(tlsCfg)) + } + conn, err := nats.Connect(url, opts...) if err != nil { return nil, nil, fmt.Errorf("nats connect: %w", err) } diff --git a/services/tool-learning/internal/adapters/nats/publisher_test.go b/services/tool-learning/internal/adapters/nats/publisher_test.go index 2b7dcc4..cfd5c48 100644 --- a/services/tool-learning/internal/adapters/nats/publisher_test.go +++ b/services/tool-learning/internal/adapters/nats/publisher_test.go @@ -90,7 +90,7 @@ func TestPublishPolicyUpdated(t *testing.T) { func TestNewPublisherFromURL(t *testing.T) { srv := startTestNATS(t) - pub, conn, err := NewPublisherFromURL(srv.ClientURL(), "daily") + pub, conn, err := NewPublisherFromURL(srv.ClientURL(), "daily", nil) if err != nil { t.Fatalf("NewPublisherFromURL: %v", err) } @@ -123,7 +123,7 @@ func TestPublisherCloseNilConn(t *testing.T) { } func TestNewPublisherFromURLInvalid(t *testing.T) { - _, _, err := NewPublisherFromURL("nats://invalid:9999", "daily") + _, _, err := NewPublisherFromURL("nats://invalid:9999", "daily", nil) if err == nil { t.Fatal("expected error connecting to invalid NATS URL") } diff --git a/services/tool-learning/internal/adapters/s3/audit_store.go b/services/tool-learning/internal/adapters/s3/audit_store.go index 5651576..0fb79a2 100644 --- a/services/tool-learning/internal/adapters/s3/audit_store.go +++ b/services/tool-learning/internal/adapters/s3/audit_store.go @@ -3,9 +3,11 @@ package s3 import ( "bytes" "context" + "crypto/tls" "encoding/json" "fmt" "io" + "net/http" "time" "github.com/minio/minio-go/v7" @@ -32,11 +34,15 @@ func NewAuditStore(client ObjectClient, bucket string) *AuditStore { } // NewAuditStoreFromConfig creates an audit store connecting to MinIO/S3. -func NewAuditStoreFromConfig(endpoint, accessKey, secretKey, bucket string, useSSL bool) (*AuditStore, error) { - client, err := minio.New(endpoint, &minio.Options{ +func NewAuditStoreFromConfig(endpoint, accessKey, secretKey, bucket string, useSSL bool, tlsCfg *tls.Config) (*AuditStore, error) { + opts := &minio.Options{ Creds: credentials.NewStaticV4(accessKey, secretKey, ""), Secure: useSSL, - }) + } + if tlsCfg != nil { + opts.Transport = &http.Transport{TLSClientConfig: tlsCfg} + } + client, err := minio.New(endpoint, opts) if err != nil { return nil, fmt.Errorf("minio client: %w", err) } diff --git a/services/tool-learning/internal/adapters/s3/audit_store_test.go b/services/tool-learning/internal/adapters/s3/audit_store_test.go index 2a2616d..9d39084 100644 --- a/services/tool-learning/internal/adapters/s3/audit_store_test.go +++ b/services/tool-learning/internal/adapters/s3/audit_store_test.go @@ -186,7 +186,7 @@ func TestEnsureBucketCheckError(t *testing.T) { } func TestNewAuditStoreFromConfig(t *testing.T) { - store, err := NewAuditStoreFromConfig("localhost:9000", "access", "secret", "bucket", false) + store, err := NewAuditStoreFromConfig("localhost:9000", "access", "secret", "bucket", false, nil) if err != nil { t.Fatalf("NewAuditStoreFromConfig: %v", err) } diff --git a/services/tool-learning/internal/adapters/valkey/policy_store.go b/services/tool-learning/internal/adapters/valkey/policy_store.go index 24754e7..9a1783c 100644 --- a/services/tool-learning/internal/adapters/valkey/policy_store.go +++ b/services/tool-learning/internal/adapters/valkey/policy_store.go @@ -2,6 +2,7 @@ package valkey import ( "context" + "crypto/tls" "encoding/json" "fmt" "time" @@ -28,11 +29,12 @@ func NewPolicyStore(client redis.Cmdable, keyPrefix string, ttl time.Duration) * } // NewPolicyStoreFromAddress creates a PolicyStore connecting to a Valkey address. -func NewPolicyStoreFromAddress(ctx context.Context, addr, password string, db int, keyPrefix string, ttl time.Duration) (*PolicyStore, error) { +func NewPolicyStoreFromAddress(ctx context.Context, addr, password string, db int, keyPrefix string, ttl time.Duration, tlsCfg *tls.Config) (*PolicyStore, error) { client := redis.NewClient(&redis.Options{ - Addr: addr, - Password: password, - DB: db, + Addr: addr, + Password: password, + DB: db, + TLSConfig: tlsCfg, }) if err := client.Ping(ctx).Err(); err != nil { return nil, fmt.Errorf("valkey ping: %w", err) diff --git a/services/tool-learning/internal/adapters/valkey/policy_store_test.go b/services/tool-learning/internal/adapters/valkey/policy_store_test.go index c8208f8..351b3c6 100644 --- a/services/tool-learning/internal/adapters/valkey/policy_store_test.go +++ b/services/tool-learning/internal/adapters/valkey/policy_store_test.go @@ -65,7 +65,7 @@ func TestPolicyRoundTripJSON(t *testing.T) { func TestPolicyStoreWithMiniredis(t *testing.T) { srv := startMiniredis(t) - store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute) + store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil) if err != nil { t.Fatalf("NewPolicyStoreFromAddress: %v", err) } @@ -111,7 +111,7 @@ func TestPolicyStoreWithMiniredis(t *testing.T) { } func TestNewPolicyStoreFromAddressFailure(t *testing.T) { - _, err := NewPolicyStoreFromAddress(context.Background(), "localhost:1", "", 0, "tp", time.Minute) + _, err := NewPolicyStoreFromAddress(context.Background(), "localhost:1", "", 0, "tp", time.Minute, nil) if err == nil { t.Fatal("expected error from invalid address") } @@ -119,7 +119,7 @@ func TestNewPolicyStoreFromAddressFailure(t *testing.T) { func TestPolicyStoreWriteBatch(t *testing.T) { srv := startMiniredis(t) - store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute) + store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil) if err != nil { t.Fatalf("NewPolicyStoreFromAddress: %v", err) } diff --git a/services/tool-learning/internal/integration/pipeline_test.go b/services/tool-learning/internal/integration/pipeline_test.go index b78ff12..063b322 100644 --- a/services/tool-learning/internal/integration/pipeline_test.go +++ b/services/tool-learning/internal/integration/pipeline_test.go @@ -159,7 +159,7 @@ func TestFullPipelineHourly(t *testing.T) { // --- Valkey (miniredis) --- redisSrv := startMiniredis(t) store, err := valkey.NewPolicyStoreFromAddress( - context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, + context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil, ) if err != nil { t.Fatalf("valkey store: %v", err) @@ -320,7 +320,7 @@ func TestPipelineWithConstraints(t *testing.T) { redisSrv := startMiniredis(t) store, err := valkey.NewPolicyStoreFromAddress( - context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, + context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil, ) if err != nil { t.Fatalf("valkey store: %v", err) @@ -414,7 +414,7 @@ func setupPipeline(t *testing.T, now time.Time, seed bool) *pipelineEnv { redisSrv := startMiniredis(t) store, err := valkey.NewPolicyStoreFromAddress( - context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, + context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil, ) if err != nil { t.Fatalf("valkey store: %v", err)