Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions services/tool-learning/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
demo
seed-lake
tool-learning
2 changes: 1 addition & 1 deletion services/tool-learning/cmd/demo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ func startServices(db *sql.DB) (
}

store, err := valkey.NewPolicyStoreFromAddress(
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 2*time.Hour,
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 2*time.Hour, nil,
)
if err != nil {
redisSrv.Close()
Expand Down
4 changes: 2 additions & 2 deletions services/tool-learning/cmd/tool-learning/adapters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func buildTestAdapters(t *testing.T) (

// Valkey via miniredis
redisSrv := startMiniredis(t)
store, err := valkey.NewPolicyStoreFromAddress(context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute)
store, err := valkey.NewPolicyStoreFromAddress(context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil)
if err != nil {
t.Fatalf("valkey: %v", err)
}
Expand All @@ -84,7 +84,7 @@ func buildTestAdapters(t *testing.T) (
pub := natspub.NewPublisher(conn, "hourly")

// S3 audit store (client creation doesn't connect)
audit, err := s3store.NewAuditStoreFromConfig("localhost:9000", "test", "test", "test-audit", false)
audit, err := s3store.NewAuditStoreFromConfig("localhost:9000", "test", "test", "test-audit", false, nil)
if err != nil {
t.Fatalf("audit store: %v", err)
}
Expand Down
72 changes: 69 additions & 3 deletions services/tool-learning/cmd/tool-learning/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package main

import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"

"github.com/underpass-ai/underpass-runtime/services/tool-learning/internal/adapters/duckdb"
Expand Down Expand Up @@ -135,14 +138,17 @@ type adapterConfig struct {
S3SecretKey string
S3Region string
S3UseSSL bool
S3TLS *tls.Config
LakeBucket string
AuditBucket string
ValkeyAddr string
ValkeyPass string
ValkeyDB int
ValkeyPfx string
ValkeyTTL time.Duration
ValkeyTLS *tls.Config
NATSURL string
NATSTLS *tls.Config
Schedule string
}

Expand All @@ -156,20 +162,52 @@ func loadConfig(schedule string) (adapterConfig, error) {
if err != nil {
return adapterConfig{}, fmt.Errorf("invalid VALKEY_TTL: %w", err)
}
valkeyTLS, err := buildClientTLS(
os.Getenv("VALKEY_TLS_ENABLED") == "true",
os.Getenv("VALKEY_TLS_CA_PATH"),
os.Getenv("VALKEY_TLS_CERT_PATH"),
os.Getenv("VALKEY_TLS_KEY_PATH"),
)
if err != nil {
return adapterConfig{}, fmt.Errorf("valkey TLS: %w", err)
}

natsTLS, err := buildClientTLS(
strings.TrimSpace(os.Getenv("NATS_TLS_MODE")) != "" && strings.TrimSpace(os.Getenv("NATS_TLS_MODE")) != "disabled",
os.Getenv("NATS_TLS_CA_PATH"),
os.Getenv("NATS_TLS_CERT_PATH"),
os.Getenv("NATS_TLS_KEY_PATH"),
)
if err != nil {
return adapterConfig{}, fmt.Errorf("nats TLS: %w", err)
}

s3TLS, err := buildClientTLS(
envOrDefault("S3_USE_SSL", "false") == "true",
os.Getenv("S3_CA_PATH"),
"", "",
)
if err != nil {
return adapterConfig{}, fmt.Errorf("s3 TLS: %w", err)
}

return adapterConfig{
S3Endpoint: envOrDefault("S3_ENDPOINT", "localhost:9000"),
S3AccessKey: envOrDefault("S3_ACCESS_KEY", ""),
S3SecretKey: envOrDefault("S3_SECRET_KEY", ""),
S3Region: envOrDefault("S3_REGION", "us-east-1"),
S3UseSSL: envOrDefault("S3_USE_SSL", "false") == "true",
S3TLS: s3TLS,
LakeBucket: envOrDefault("LAKE_BUCKET", "telemetry-lake"),
AuditBucket: envOrDefault("AUDIT_BUCKET", "policy-audit"),
ValkeyAddr: envOrDefault("VALKEY_ADDR", "localhost:6379"),
ValkeyPass: os.Getenv("VALKEY_PASSWORD"),
ValkeyDB: valkeyDB,
ValkeyPfx: envOrDefault("VALKEY_KEY_PREFIX", "tool_policy"),
ValkeyTTL: valkeyTTL,
ValkeyTLS: valkeyTLS,
NATSURL: envOrDefault("NATS_URL", "nats://localhost:4222"),
NATSTLS: natsTLS,
Schedule: schedule,
}, nil
}
Expand All @@ -190,21 +228,21 @@ func buildAdapters(cfg adapterConfig, logger *slog.Logger) (
}
logger.Info(logAdapterReady, "adapter", "duckdb-lake-reader", "bucket", cfg.LakeBucket)

store, err := valkey.NewPolicyStoreFromAddress(context.Background(), cfg.ValkeyAddr, cfg.ValkeyPass, cfg.ValkeyDB, cfg.ValkeyPfx, cfg.ValkeyTTL)
store, err := valkey.NewPolicyStoreFromAddress(context.Background(), cfg.ValkeyAddr, cfg.ValkeyPass, cfg.ValkeyDB, cfg.ValkeyPfx, cfg.ValkeyTTL, cfg.ValkeyTLS)
if err != nil {
_ = lake.Close()
return nil, nil, nil, nil, nil, fmt.Errorf("valkey policy store: %w", err)
}
logger.Info(logAdapterReady, "adapter", "valkey-policy-store", "addr", cfg.ValkeyAddr)

pub, natsConn, err := natspub.NewPublisherFromURL(cfg.NATSURL, cfg.Schedule)
pub, natsConn, err := natspub.NewPublisherFromURL(cfg.NATSURL, cfg.Schedule, cfg.NATSTLS)
if err != nil {
_ = lake.Close()
return nil, nil, nil, nil, nil, fmt.Errorf("nats publisher: %w", err)
}
logger.Info(logAdapterReady, "adapter", "nats-publisher", "url", cfg.NATSURL)

audit, err := s3.NewAuditStoreFromConfig(cfg.S3Endpoint, cfg.S3AccessKey, cfg.S3SecretKey, cfg.AuditBucket, cfg.S3UseSSL)
audit, err := s3.NewAuditStoreFromConfig(cfg.S3Endpoint, cfg.S3AccessKey, cfg.S3SecretKey, cfg.AuditBucket, cfg.S3UseSSL, cfg.S3TLS)
if err != nil {
_ = lake.Close()
natsConn.Close()
Expand All @@ -221,6 +259,34 @@ func buildAdapters(cfg adapterConfig, logger *slog.Logger) (
return lake, store, pub, audit, cleanup, nil
}

// buildClientTLS builds a *tls.Config for outgoing connections.
// Returns nil when enabled is false (TLS disabled).
func buildClientTLS(enabled bool, caPath, certPath, keyPath string) (*tls.Config, error) {
if !enabled {
return nil, nil
}
cfg := &tls.Config{MinVersion: tls.VersionTLS13}
if caPath != "" {
data, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("read CA %s: %w", caPath, err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("no valid certs in %s", caPath)
}
cfg.RootCAs = pool
}
if certPath != "" && keyPath != "" {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("load cert/key: %w", err)
}
cfg.Certificates = []tls.Certificate{cert}
}
return cfg, nil
}

func envOrDefault(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
Expand Down
9 changes: 7 additions & 2 deletions services/tool-learning/internal/adapters/nats/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nats

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"time"
Expand Down Expand Up @@ -35,8 +36,12 @@ func NewPublisher(conn *nats.Conn, schedule string) *Publisher {
}

// NewPublisherFromURL connects to NATS and returns a publisher.
func NewPublisherFromURL(url, schedule string) (*Publisher, *nats.Conn, error) {
conn, err := nats.Connect(url)
func NewPublisherFromURL(url, schedule string, tlsCfg *tls.Config) (*Publisher, *nats.Conn, error) {
var opts []nats.Option
if tlsCfg != nil {
opts = append(opts, nats.Secure(tlsCfg))
}
conn, err := nats.Connect(url, opts...)
if err != nil {
return nil, nil, fmt.Errorf("nats connect: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestPublishPolicyUpdated(t *testing.T) {
func TestNewPublisherFromURL(t *testing.T) {
srv := startTestNATS(t)

pub, conn, err := NewPublisherFromURL(srv.ClientURL(), "daily")
pub, conn, err := NewPublisherFromURL(srv.ClientURL(), "daily", nil)
if err != nil {
t.Fatalf("NewPublisherFromURL: %v", err)
}
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestPublisherCloseNilConn(t *testing.T) {
}

func TestNewPublisherFromURLInvalid(t *testing.T) {
_, _, err := NewPublisherFromURL("nats://invalid:9999", "daily")
_, _, err := NewPublisherFromURL("nats://invalid:9999", "daily", nil)
if err == nil {
t.Fatal("expected error connecting to invalid NATS URL")
}
Expand Down
12 changes: 9 additions & 3 deletions services/tool-learning/internal/adapters/s3/audit_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package s3
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"time"

"github.com/minio/minio-go/v7"
Expand All @@ -32,11 +34,15 @@ func NewAuditStore(client ObjectClient, bucket string) *AuditStore {
}

// NewAuditStoreFromConfig creates an audit store connecting to MinIO/S3.
func NewAuditStoreFromConfig(endpoint, accessKey, secretKey, bucket string, useSSL bool) (*AuditStore, error) {
client, err := minio.New(endpoint, &minio.Options{
func NewAuditStoreFromConfig(endpoint, accessKey, secretKey, bucket string, useSSL bool, tlsCfg *tls.Config) (*AuditStore, error) {
opts := &minio.Options{
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
Secure: useSSL,
})
}
if tlsCfg != nil {
opts.Transport = &http.Transport{TLSClientConfig: tlsCfg}
}
client, err := minio.New(endpoint, opts)
if err != nil {
return nil, fmt.Errorf("minio client: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func TestEnsureBucketCheckError(t *testing.T) {
}

func TestNewAuditStoreFromConfig(t *testing.T) {
store, err := NewAuditStoreFromConfig("localhost:9000", "access", "secret", "bucket", false)
store, err := NewAuditStoreFromConfig("localhost:9000", "access", "secret", "bucket", false, nil)
if err != nil {
t.Fatalf("NewAuditStoreFromConfig: %v", err)
}
Expand Down
10 changes: 6 additions & 4 deletions services/tool-learning/internal/adapters/valkey/policy_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package valkey

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"time"
Expand All @@ -28,11 +29,12 @@ func NewPolicyStore(client redis.Cmdable, keyPrefix string, ttl time.Duration) *
}

// NewPolicyStoreFromAddress creates a PolicyStore connecting to a Valkey address.
func NewPolicyStoreFromAddress(ctx context.Context, addr, password string, db int, keyPrefix string, ttl time.Duration) (*PolicyStore, error) {
func NewPolicyStoreFromAddress(ctx context.Context, addr, password string, db int, keyPrefix string, ttl time.Duration, tlsCfg *tls.Config) (*PolicyStore, error) {
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
Addr: addr,
Password: password,
DB: db,
TLSConfig: tlsCfg,
})
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("valkey ping: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestPolicyRoundTripJSON(t *testing.T) {

func TestPolicyStoreWithMiniredis(t *testing.T) {
srv := startMiniredis(t)
store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute)
store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil)
if err != nil {
t.Fatalf("NewPolicyStoreFromAddress: %v", err)
}
Expand Down Expand Up @@ -111,15 +111,15 @@ func TestPolicyStoreWithMiniredis(t *testing.T) {
}

func TestNewPolicyStoreFromAddressFailure(t *testing.T) {
_, err := NewPolicyStoreFromAddress(context.Background(), "localhost:1", "", 0, "tp", time.Minute)
_, err := NewPolicyStoreFromAddress(context.Background(), "localhost:1", "", 0, "tp", time.Minute, nil)
if err == nil {
t.Fatal("expected error from invalid address")
}
}

func TestPolicyStoreWriteBatch(t *testing.T) {
srv := startMiniredis(t)
store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute)
store, err := NewPolicyStoreFromAddress(context.Background(), srv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil)
if err != nil {
t.Fatalf("NewPolicyStoreFromAddress: %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions services/tool-learning/internal/integration/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestFullPipelineHourly(t *testing.T) {
// --- Valkey (miniredis) ---
redisSrv := startMiniredis(t)
store, err := valkey.NewPolicyStoreFromAddress(
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute,
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil,
)
if err != nil {
t.Fatalf("valkey store: %v", err)
Expand Down Expand Up @@ -320,7 +320,7 @@ func TestPipelineWithConstraints(t *testing.T) {

redisSrv := startMiniredis(t)
store, err := valkey.NewPolicyStoreFromAddress(
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute,
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil,
)
if err != nil {
t.Fatalf("valkey store: %v", err)
Expand Down Expand Up @@ -414,7 +414,7 @@ func setupPipeline(t *testing.T, now time.Time, seed bool) *pipelineEnv {

redisSrv := startMiniredis(t)
store, err := valkey.NewPolicyStoreFromAddress(
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute,
context.Background(), redisSrv.Addr(), "", 0, "tool_policy", 10*time.Minute, nil,
)
if err != nil {
t.Fatalf("valkey store: %v", err)
Expand Down
Loading