From 76dbd6295a08f49149a88838ad113e2643149b0a Mon Sep 17 00:00:00 2001 From: Danial Beg Date: Mon, 16 Mar 2026 17:55:32 -0700 Subject: [PATCH] Fix codebase audit issues: tenant filtering, orphaned migrations, CLI parity - Add tenant_id filtering to QueryCostTimeseries interface and both SQLite/Postgres implementations - Add --tenant flag and Postgres driver support to CLI costs command - Pass tenant parameter through dashboard summary/timeseries handlers - Remove orphaned root-level migration files (only sqlite/ and postgres/ subdirectories are embedded) - Update all test stubs/mocks for new QueryCostTimeseries signature - Add .gitignore entry for built binary --- .gitignore | 1 + cmd/agentledger/costs.go | 21 +++++-- internal/budget/budget_test.go | 2 +- internal/dashboard/handlers.go | 19 +++--- internal/dashboard/handlers_test.go | 62 ++++++++++++++++++- internal/ledger/ledger.go | 4 +- .../migrations/001_create_usage_records.sql | 27 -------- .../migrations/002_create_agent_sessions.sql | 19 ------ .../ledger/migrations/003_add_tenant_id.sql | 9 --- .../migrations/004_create_admin_config.sql | 9 --- internal/ledger/postgres.go | 15 +++-- internal/ledger/postgres_test.go | 2 +- internal/ledger/recorder_test.go | 4 +- internal/ledger/sqlite.go | 15 +++-- internal/mcp/interceptor_test.go | 2 +- internal/proxy/proxy_test.go | 2 +- 16 files changed, 119 insertions(+), 94 deletions(-) delete mode 100644 internal/ledger/migrations/001_create_usage_records.sql delete mode 100644 internal/ledger/migrations/002_create_agent_sessions.sql delete mode 100644 internal/ledger/migrations/003_add_tenant_id.sql delete mode 100644 internal/ledger/migrations/004_create_admin_config.sql diff --git a/.gitignore b/.gitignore index aad6452..b7ab767 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ Thumbs.db # Build dist/ build/output/ +/agentledger # Node (dashboard) web/node_modules/ diff --git a/cmd/agentledger/costs.go b/cmd/agentledger/costs.go index 3f4eb5f..652d02d 100644 --- a/cmd/agentledger/costs.go +++ b/cmd/agentledger/costs.go @@ -31,30 +31,38 @@ func costsCmd() *cobra.Command { configPath string last string groupBy string + tenant string ) cmd := &cobra.Command{ Use: "costs", Short: "Show cost report", RunE: func(_ *cobra.Command, _ []string) error { - return runCosts(configPath, last, groupBy) + return runCosts(configPath, last, groupBy, tenant) }, } cmd.Flags().StringVarP(&configPath, "config", "c", "", "path to config file") cmd.Flags().StringVar(&last, "last", "24h", "time window (e.g., 1h, 24h, 7d)") cmd.Flags().StringVar(&groupBy, "by", "model", "group by: model, provider, key, agent, session") + cmd.Flags().StringVar(&tenant, "tenant", "", "filter by tenant ID") return cmd } -func runCosts(configPath, last, groupBy string) error { +func runCosts(configPath, last, groupBy, tenant string) error { cfg, err := config.Load(configPath) if err != nil { return err } - store, err := ledger.NewSQLite(cfg.Storage.DSN) + var store ledger.Ledger + switch cfg.Storage.Driver { + case "postgres": + store, err = ledger.NewPostgres(cfg.Storage.DSN, cfg.Storage.MaxOpenConns, cfg.Storage.MaxIdleConns) + default: + store, err = ledger.NewSQLite(cfg.Storage.DSN) + } if err != nil { return err } @@ -67,9 +75,10 @@ func runCosts(configPath, last, groupBy string) error { now := time.Now() filter := ledger.CostFilter{ - Since: now.Add(-window), - Until: now, - GroupBy: groupBy, + Since: now.Add(-window), + Until: now, + GroupBy: groupBy, + TenantID: tenant, } entries, err := store.QueryCosts(context.Background(), filter) diff --git a/internal/budget/budget_test.go b/internal/budget/budget_test.go index 3b0ed59..1e66c52 100644 --- a/internal/budget/budget_test.go +++ b/internal/budget/budget_test.go @@ -31,7 +31,7 @@ func (s *stubLedger) GetTotalSpend(_ context.Context, _ string, since, _ time.Ti func (s *stubLedger) GetTotalSpendByTenant(_ context.Context, _ string, _, _ time.Time) (float64, error) { return 0, nil } -func (s *stubLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]ledger.TimeseriesPoint, error) { +func (s *stubLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TimeseriesPoint, error) { return nil, nil } func (s *stubLedger) Close() error { return nil } diff --git a/internal/dashboard/handlers.go b/internal/dashboard/handlers.go index c59a1d5..96978e0 100644 --- a/internal/dashboard/handlers.go +++ b/internal/dashboard/handlers.go @@ -33,12 +33,14 @@ func (h *Handler) handleSummary(w http.ResponseWriter, r *http.Request) { now := time.Now().UTC() dayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC) + tenantID := r.URL.Query().Get("tenant") // Get today's costs by model. todayCosts, err := h.ledger.QueryCosts(r.Context(), ledger.CostFilter{ - Since: dayStart, - Until: now, - GroupBy: "model", + Since: dayStart, + Until: now, + GroupBy: "model", + TenantID: tenantID, }) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) @@ -54,9 +56,10 @@ func (h *Handler) handleSummary(w http.ResponseWriter, r *http.Request) { // Get month's costs. monthCosts, err := h.ledger.QueryCosts(r.Context(), ledger.CostFilter{ - Since: monthStart, - Until: now, - GroupBy: "model", + Since: monthStart, + Until: now, + GroupBy: "model", + TenantID: tenantID, }) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) @@ -93,10 +96,12 @@ func (h *Handler) handleTimeseries(w http.ResponseWriter, r *http.Request) { hours = 24 } + tenantID := r.URL.Query().Get("tenant") + now := time.Now().UTC() since := now.Add(-time.Duration(hours) * time.Hour) - points, err := h.ledger.QueryCostTimeseries(r.Context(), interval, since, now) + points, err := h.ledger.QueryCostTimeseries(r.Context(), interval, since, now, tenantID) if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return diff --git a/internal/dashboard/handlers_test.go b/internal/dashboard/handlers_test.go index 42471a7..28be7c0 100644 --- a/internal/dashboard/handlers_test.go +++ b/internal/dashboard/handlers_test.go @@ -26,7 +26,7 @@ func (s *stubLedger) GetTotalSpend(_ context.Context, _ string, _, _ time.Time) func (s *stubLedger) GetTotalSpendByTenant(_ context.Context, _ string, _, _ time.Time) (float64, error) { return 0, nil } -func (s *stubLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]ledger.TimeseriesPoint, error) { +func (s *stubLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TimeseriesPoint, error) { return s.timeseries, nil } func (s *stubLedger) Close() error { return nil } @@ -120,6 +120,66 @@ func TestHandleCosts(t *testing.T) { } } +func TestHandleCostsWithTenant(t *testing.T) { + store := &stubLedger{ + costs: []ledger.CostEntry{ + {Model: "gpt-4o-mini", Requests: 3, TotalCostUSD: 0.15}, + }, + } + h := NewHandler(store, nil) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/costs?group_by=model&tenant=alpha", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } +} + +func TestHandleSummaryWithTenant(t *testing.T) { + store := &stubLedger{ + costs: []ledger.CostEntry{ + {Model: "gpt-4o-mini", Requests: 5, TotalCostUSD: 0.25}, + }, + } + h := NewHandler(store, nil) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/summary?tenant=beta", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } +} + +func TestHandleTimeseriesWithTenant(t *testing.T) { + store := &stubLedger{ + timeseries: []ledger.TimeseriesPoint{ + {Timestamp: time.Now(), CostUSD: 0.10, Requests: 2}, + }, + } + h := NewHandler(store, nil) + + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + req := httptest.NewRequest("GET", "/api/dashboard/timeseries?tenant=gamma", nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + if w.Code != 200 { + t.Fatalf("status = %d, want 200", w.Code) + } +} + func TestHandleSessionsWithoutTracker(t *testing.T) { store := &stubLedger{} h := NewHandler(store, nil) diff --git a/internal/ledger/ledger.go b/internal/ledger/ledger.go index 90910d5..e5eb989 100644 --- a/internal/ledger/ledger.go +++ b/internal/ledger/ledger.go @@ -22,8 +22,8 @@ type Ledger interface { GetTotalSpendByTenant(ctx context.Context, tenantID string, since, until time.Time) (float64, error) // QueryCostTimeseries returns cost and request counts bucketed by time interval. - // interval should be "hour" or "day". - QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time) ([]TimeseriesPoint, error) + // interval should be "hour" or "day". tenantID is optional (empty = all tenants). + QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TimeseriesPoint, error) // Close releases any held resources. Close() error diff --git a/internal/ledger/migrations/001_create_usage_records.sql b/internal/ledger/migrations/001_create_usage_records.sql deleted file mode 100644 index 06bfd9d..0000000 --- a/internal/ledger/migrations/001_create_usage_records.sql +++ /dev/null @@ -1,27 +0,0 @@ --- +goose Up -CREATE TABLE IF NOT EXISTS usage_records ( - id TEXT PRIMARY KEY, - timestamp DATETIME NOT NULL, - provider TEXT NOT NULL, - model TEXT NOT NULL, - api_key_hash TEXT NOT NULL DEFAULT '', - input_tokens INTEGER NOT NULL DEFAULT 0, - output_tokens INTEGER NOT NULL DEFAULT 0, - total_tokens INTEGER NOT NULL DEFAULT 0, - cost_usd REAL NOT NULL DEFAULT 0.0, - estimated BOOLEAN NOT NULL DEFAULT FALSE, - duration_ms INTEGER NOT NULL DEFAULT 0, - status_code INTEGER NOT NULL DEFAULT 0, - path TEXT NOT NULL DEFAULT '', - agent_id TEXT NOT NULL DEFAULT '', - session_id TEXT NOT NULL DEFAULT '', - user_id TEXT NOT NULL DEFAULT '' -); - -CREATE INDEX idx_usage_records_timestamp ON usage_records(timestamp); -CREATE INDEX idx_usage_records_api_key_hash ON usage_records(api_key_hash); -CREATE INDEX idx_usage_records_model ON usage_records(model); -CREATE INDEX idx_usage_records_provider ON usage_records(provider); - --- +goose Down -DROP TABLE IF EXISTS usage_records; diff --git a/internal/ledger/migrations/002_create_agent_sessions.sql b/internal/ledger/migrations/002_create_agent_sessions.sql deleted file mode 100644 index 1df9ba8..0000000 --- a/internal/ledger/migrations/002_create_agent_sessions.sql +++ /dev/null @@ -1,19 +0,0 @@ --- +goose Up -CREATE TABLE IF NOT EXISTS agent_sessions ( - id TEXT PRIMARY KEY, - agent_id TEXT NOT NULL DEFAULT '', - user_id TEXT NOT NULL DEFAULT '', - task TEXT NOT NULL DEFAULT '', - started_at DATETIME NOT NULL, - ended_at DATETIME, - status TEXT NOT NULL DEFAULT 'active', - call_count INTEGER NOT NULL DEFAULT 0, - total_cost_usd REAL NOT NULL DEFAULT 0.0, - total_tokens INTEGER NOT NULL DEFAULT 0 -); - -CREATE INDEX IF NOT EXISTS idx_sessions_status ON agent_sessions(status); -CREATE INDEX IF NOT EXISTS idx_sessions_agent_id ON agent_sessions(agent_id); - --- +goose Down -DROP TABLE IF EXISTS agent_sessions; diff --git a/internal/ledger/migrations/003_add_tenant_id.sql b/internal/ledger/migrations/003_add_tenant_id.sql deleted file mode 100644 index 5b20d91..0000000 --- a/internal/ledger/migrations/003_add_tenant_id.sql +++ /dev/null @@ -1,9 +0,0 @@ --- +goose Up -ALTER TABLE usage_records ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''; -ALTER TABLE agent_sessions ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''; -CREATE INDEX idx_usage_records_tenant_id ON usage_records(tenant_id); -CREATE INDEX idx_sessions_tenant_id ON agent_sessions(tenant_id); - --- +goose Down -DROP INDEX IF EXISTS idx_sessions_tenant_id; -DROP INDEX IF EXISTS idx_usage_records_tenant_id; diff --git a/internal/ledger/migrations/004_create_admin_config.sql b/internal/ledger/migrations/004_create_admin_config.sql deleted file mode 100644 index 68b48bc..0000000 --- a/internal/ledger/migrations/004_create_admin_config.sql +++ /dev/null @@ -1,9 +0,0 @@ --- +goose Up -CREATE TABLE IF NOT EXISTS admin_config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL DEFAULT '', - updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP -); - --- +goose Down -DROP TABLE IF EXISTS admin_config; diff --git a/internal/ledger/postgres.go b/internal/ledger/postgres.go index db548a4..7a423f0 100644 --- a/internal/ledger/postgres.go +++ b/internal/ledger/postgres.go @@ -127,22 +127,29 @@ func (p *Postgres) QueryCosts(ctx context.Context, filter CostFilter) ([]CostEnt return entries, rows.Err() } -func (p *Postgres) QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time) ([]TimeseriesPoint, error) { +func (p *Postgres) QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TimeseriesPoint, error) { bucket := "date_trunc('hour', timestamp)" if interval == "day" { bucket = "date_trunc('day', timestamp)" } + where := "timestamp >= $1 AND timestamp <= $2" + args := []any{since.UTC(), until.UTC()} + if tenantID != "" { + args = append(args, tenantID) + where += fmt.Sprintf(" AND tenant_id = $%d", len(args)) + } + q := fmt.Sprintf(`SELECT %s as bucket, COALESCE(SUM(cost_usd), 0), COUNT(*) FROM usage_records - WHERE timestamp >= $1 AND timestamp <= $2 + WHERE %s GROUP BY bucket - ORDER BY bucket ASC`, bucket) + ORDER BY bucket ASC`, bucket, where) - rows, err := p.db.QueryContext(ctx, q, since.UTC(), until.UTC()) + rows, err := p.db.QueryContext(ctx, q, args...) if err != nil { return nil, fmt.Errorf("querying cost timeseries: %w", err) } diff --git a/internal/ledger/postgres_test.go b/internal/ledger/postgres_test.go index 97063ce..2a492b0 100644 --- a/internal/ledger/postgres_test.go +++ b/internal/ledger/postgres_test.go @@ -154,7 +154,7 @@ func TestPostgres_QueryCostTimeseries(t *testing.T) { } } - points, err := pg.QueryCostTimeseries(ctx, "hour", hourAgo.Add(-time.Minute), now.Add(time.Minute)) + points, err := pg.QueryCostTimeseries(ctx, "hour", hourAgo.Add(-time.Minute), now.Add(time.Minute), "") if err != nil { t.Fatalf("QueryCostTimeseries: %v", err) } diff --git a/internal/ledger/recorder_test.go b/internal/ledger/recorder_test.go index e45585a..b9307cb 100644 --- a/internal/ledger/recorder_test.go +++ b/internal/ledger/recorder_test.go @@ -30,7 +30,7 @@ func (c *countingLedger) GetTotalSpendByTenant(_ context.Context, _ string, _, _ return 0, nil } -func (c *countingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]TimeseriesPoint, error) { +func (c *countingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]TimeseriesPoint, error) { return nil, nil } @@ -53,7 +53,7 @@ func (f *failingLedger) GetTotalSpendByTenant(_ context.Context, _ string, _, _ return 0, nil } -func (f *failingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]TimeseriesPoint, error) { +func (f *failingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]TimeseriesPoint, error) { return nil, nil } diff --git a/internal/ledger/sqlite.go b/internal/ledger/sqlite.go index 0d0fd0c..4e810d1 100644 --- a/internal/ledger/sqlite.go +++ b/internal/ledger/sqlite.go @@ -129,7 +129,7 @@ func (s *SQLite) QueryCosts(ctx context.Context, filter CostFilter) ([]CostEntry return entries, rows.Err() } -func (s *SQLite) QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time) ([]TimeseriesPoint, error) { +func (s *SQLite) QueryCostTimeseries(ctx context.Context, interval string, since, until time.Time, tenantID string) ([]TimeseriesPoint, error) { // Go's time.Time stores as "2006-01-02 15:04:05.999999 +0000 UTC" in SQLite, // but strftime only parses ISO8601. Use substr to extract the datetime portion. bucket := "strftime('%Y-%m-%d %H:00:00', substr(timestamp, 1, 19))" @@ -137,16 +137,23 @@ func (s *SQLite) QueryCostTimeseries(ctx context.Context, interval string, since bucket = "strftime('%Y-%m-%d 00:00:00', substr(timestamp, 1, 19))" } + where := "timestamp >= ? AND timestamp <= ?" + args := []any{since.UTC(), until.UTC()} + if tenantID != "" { + where += " AND tenant_id = ?" + args = append(args, tenantID) + } + q := fmt.Sprintf(`SELECT %s as bucket, COALESCE(SUM(cost_usd), 0), COUNT(*) FROM usage_records - WHERE timestamp >= ? AND timestamp <= ? + WHERE %s GROUP BY bucket - ORDER BY bucket ASC`, bucket) + ORDER BY bucket ASC`, bucket, where) - rows, err := s.db.QueryContext(ctx, q, since.UTC(), until.UTC()) + rows, err := s.db.QueryContext(ctx, q, args...) if err != nil { return nil, fmt.Errorf("querying cost timeseries: %w", err) } diff --git a/internal/mcp/interceptor_test.go b/internal/mcp/interceptor_test.go index 03aa6bb..e15abf8 100644 --- a/internal/mcp/interceptor_test.go +++ b/internal/mcp/interceptor_test.go @@ -38,7 +38,7 @@ func (r *recordingLedger) GetTotalSpendByTenant(_ context.Context, _ string, _, return 0, nil } -func (r *recordingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]ledger.TimeseriesPoint, error) { +func (r *recordingLedger) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TimeseriesPoint, error) { return nil, nil } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 9936d4e..b94ef4b 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -43,7 +43,7 @@ func (m *mockStore) GetTotalSpendByTenant(_ context.Context, _ string, _, _ time return m.totalSpend, nil } -func (m *mockStore) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time) ([]ledger.TimeseriesPoint, error) { +func (m *mockStore) QueryCostTimeseries(_ context.Context, _ string, _, _ time.Time, _ string) ([]ledger.TimeseriesPoint, error) { return nil, nil }