diff --git a/go.mod b/go.mod index 6a2a36f0c..caae622c9 100644 --- a/go.mod +++ b/go.mod @@ -99,7 +99,11 @@ require ( github.com/wasilibs/go-pgquery v0.0.0-20250409022910-10ac41983c07 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 + go.opentelemetry.io/contrib/propagators/b3 v1.20.0 + go.opentelemetry.io/contrib/propagators/ot v1.20.0 go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 go.opentelemetry.io/otel/exporters/prometheus v0.62.0 go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 @@ -252,11 +256,7 @@ require ( go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/detectors/gcp v1.43.0 // indirect - go.opentelemetry.io/contrib/propagators/b3 v1.20.0 // indirect - go.opentelemetry.io/contrib/propagators/ot v1.20.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect diff --git a/pkg/cmd/serve.go b/pkg/cmd/serve.go index 5048097ce..92dff882c 100644 --- a/pkg/cmd/serve.go +++ b/pkg/cmd/serve.go @@ -9,9 +9,9 @@ import ( "github.com/fatih/color" "github.com/jzelinskie/cobrautil/v2" - "github.com/jzelinskie/cobrautil/v2/cobraotel" "github.com/spf13/cobra" + log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/internal/telemetry" "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/server" @@ -219,11 +219,8 @@ func RegisterServeFlags(cmd 
*cobra.Command, config *server.Config) error { return fmt.Errorf("could not register stored schema cache flags: %w", err) } - tracingFlags := nfs.FlagSet(BoldBlue("Tracing")) // Flags for tracing - // NOTE: cobraotel.New takes service name as an arg rather than command name. - otel := cobraotel.New("spicedb") - otel.RegisterFlags(tracingFlags) + server.RegisterOTelFlags(cmd) loggingFlagSet := nfs.FlagSet(BoldBlue("Logging")) loggingFlagSet.BoolVar(&config.EnableRequestLogs, "grpc-log-requests-enabled", false, "enable logging of API request payloads") @@ -265,14 +262,23 @@ func NewServeCommand(programName string, config *server.Config) *cobra.Command { Long: "start a SpiceDB server", PreRunE: server.DefaultPreRunE(programName), RunE: termination.PublishError(func(cmd *cobra.Command, args []string) error { - server, err := config.Complete(cmd.Context()) + srv, err := config.Complete(cmd.Context()) if err != nil { return err } signalctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - return server.Run(signalctx) + defer func() { + // Shutdown OTel provider to ensure all traces are flushed + if provider := server.OTelProviderFromContext(cmd.Context()); provider != nil { + if err := server.ShutdownOTelProvider(context.Background(), provider); err != nil { + log.Warn().Err(err).Msg("failed to cleanly shutdown OpenTelemetry provider") + } + } + }() + + return srv.Run(signalctx) }), Example: server.ServeExample(programName), } diff --git a/pkg/cmd/server/defaults.go b/pkg/cmd/server/defaults.go index febda4da7..a6ef3d41a 100644 --- a/pkg/cmd/server/defaults.go +++ b/pkg/cmd/server/defaults.go @@ -19,7 +19,6 @@ import ( grpclog "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector" "github.com/jzelinskie/cobrautil/v2" - "github.com/jzelinskie/cobrautil/v2/cobraotel" "github.com/jzelinskie/cobrautil/v2/cobraproclimits" 
"github.com/jzelinskie/cobrautil/v2/cobrazerolog" "github.com/prometheus/client_golang/prometheus" @@ -82,9 +81,7 @@ func DefaultPreRunE(programName string) cobrautil.CobraRunFunc { // and zero under the same load and 0.9 cobraproclimits.SetMemLimitRunE(memlimit.WithRatio(0.9)), cobraproclimits.SetProcLimitRunE(), - cobraotel.New("spicedb", - cobraotel.WithLogger(zerologr.New(&logging.Logger)), - ).RunE(), + OTelPreRunE, releases.CheckAndLogRunE(), runtime.RunE(), ) diff --git a/pkg/cmd/server/otel.go b/pkg/cmd/server/otel.go new file mode 100644 index 000000000..7fbbf077e --- /dev/null +++ b/pkg/cmd/server/otel.go @@ -0,0 +1,252 @@ +// pkg/cmd/server/otel.go +// +// This file replicates the OpenTelemetry provider initialization previously +// provided by github.com/jzelinskie/cobrautil/v2/cobraotel. +// +// Motivation: +// - Issue #712: otel-provider defaults to "none" so OTEL_* env vars alone +// cannot activate tracing. Native ownership lets us address this properly. +// - Issue #3095: cobraotel owns the TracerProvider with no way for the +// signal handler to call Shutdown/ForceFlush. Traces are dropped on +// SIGTERM. Native ownership closes this lifecycle gap. + +package server + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/spf13/cobra" + "go.opentelemetry.io/contrib/propagators/b3" + "go.opentelemetry.io/contrib/propagators/ot" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.7.0" +) + +// otelProviderContextKey is the unexported context key for the TracerProvider. +type otelProviderContextKey struct{} + +// OTelShutdownTimeout is the maximum time for flushing spans on shutdown. 
+const OTelShutdownTimeout = 30 * time.Second
+
+// otelShutdowner is the subset of the *sdktrace.TracerProvider API used to
+// flush and stop tracing. Depending on this local interface rather than the
+// concrete type lets ShutdownOTelProvider be exercised with a test double.
+type otelShutdowner interface {
+	// Shutdown stops the provider; ctx bounds how long that may take.
+	Shutdown(ctx context.Context) error
+	// ForceFlush asks the provider to export spans it is still holding.
+	ForceFlush(ctx context.Context) error
+}
+
+// Compile-time check that the SDK provider returned by InitOTelProvider
+// satisfies otelShutdowner.
+var _ otelShutdowner = (*sdktrace.TracerProvider)(nil)
+
+// RegisterOTelFlags registers all OpenTelemetry flags on cmd.
+// The flags registered here are identical in name and default value to those
+// previously registered by cobraotel.RegisterFlags.
+func RegisterOTelFlags(cmd *cobra.Command) {
+	// Shared by the current and the legacy endpoint flag.
+	const endpointUsage = `OpenTelemetry collector endpoint - the endpoint can also be set by using environment variables`
+
+	f := cmd.Flags()
+	f.String("otel-provider", "none",
+		`OpenTelemetry provider for tracing ("none", "otlpgrpc", "otlphttp")`)
+	f.String("otel-endpoint", "", endpointUsage)
+	f.String("otel-service-name", "spicedb",
+		`service name for trace data`)
+	f.String("otel-trace-propagator", "w3c",
+		`OpenTelemetry trace propagation format ("b3", "w3c", "ottrace"). Add multiple propagators separated by comma.`)
+	f.Bool("otel-insecure", false,
+		`connect to the OpenTelemetry collector in plaintext`)
+	f.StringToString("otel-headers", nil,
+		`key=value pairs sent as headers to the OTel provider`)
+	f.Float64("otel-sample-ratio", 0.01,
+		`ratio of traces that are sampled`)
+
+	// Legacy flags: still accepted on the command line but hidden from help.
+	// MarkHidden only fails for an unknown flag name, which cannot happen for
+	// a flag registered on the preceding line, so its error is discarded.
+	f.String("otel-jaeger-endpoint", "", endpointUsage)
+	_ = f.MarkHidden("otel-jaeger-endpoint")
+	f.String("otel-jaeger-service-name", "spicedb", "service name for trace data")
+	_ = f.MarkHidden("otel-jaeger-service-name")
+}
+
+// InitOTelProvider reads otel-* flags from cmd, builds a TracerProvider, sets
+// it as the global OTel provider, and returns it for lifecycle management.
+//
+// Returns (nil, nil) when otel-provider is "none" or empty. Callers must
+// handle nil.
+func InitOTelProvider(cmd *cobra.Command) (otelShutdowner, error) {
+	flags := cmd.Flags()
+
+	// A command used outside Execute() may not have a context attached yet.
+	ctx := cmd.Context()
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	providerName, err := flags.GetString("otel-provider")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-provider: %w", err)
+	}
+	providerName = strings.TrimSpace(strings.ToLower(providerName))
+
+	// Tracing is disabled: the global provider and propagator stay untouched.
+	if providerName == "none" || providerName == "" {
+		return nil, nil
+	}
+
+	endpoint, err := flags.GetString("otel-endpoint")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-endpoint: %w", err)
+	}
+
+	serviceName, err := flags.GetString("otel-service-name")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-service-name: %w", err)
+	}
+
+	propagatorName, err := flags.GetString("otel-trace-propagator")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-trace-propagator: %w", err)
+	}
+
+	insecureConn, err := flags.GetBool("otel-insecure")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-insecure: %w", err)
+	}
+
+	headers, err := flags.GetStringToString("otel-headers")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-headers: %w", err)
+	}
+
+	sampleRatio, err := flags.GetFloat64("otel-sample-ratio")
+	if err != nil {
+		return nil, fmt.Errorf("reading otel-sample-ratio: %w", err)
+	}
+
+	res, err := resource.New(ctx,
+		resource.WithAttributes(semconv.ServiceNameKey.String(serviceName)),
+		resource.WithProcess(),
+		resource.WithOS(),
+		resource.WithHost(),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("building OTel resource: %w", err)
+	}
+
+	exporter, err := newOTelSpanExporter(ctx, providerName, endpoint, insecureConn, headers)
+	if err != nil {
+		return nil, err
+	}
+
+	tp := sdktrace.NewTracerProvider(
+		// Follow the parent span's sampling decision; sample roots by ratio.
+		sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(sampleRatio))),
+		sdktrace.WithBatcher(exporter),
+		sdktrace.WithResource(res),
+	)
+
+	otel.SetTracerProvider(tp)
+	otel.SetTextMapPropagator(buildPropagator(propagatorName))
+
+	return tp, nil
+}
+
+// newOTelSpanExporter builds the OTLP span exporter selected by providerName
+// ("otlpgrpc" or "otlphttp"); any other name is rejected with an error.
+// endpoint, insecureConn and headers are only forwarded to the exporter when
+// they are non-empty/true, so unset flags pass no corresponding option.
+func newOTelSpanExporter(
+	ctx context.Context,
+	providerName, endpoint string,
+	insecureConn bool,
+	headers map[string]string,
+) (sdktrace.SpanExporter, error) {
+	switch providerName {
+	case "otlpgrpc":
+		var opts []otlptracegrpc.Option
+		if endpoint != "" {
+			opts = append(opts, otlptracegrpc.WithEndpoint(endpoint))
+		}
+		if insecureConn {
+			opts = append(opts, otlptracegrpc.WithInsecure())
+		}
+		if len(headers) > 0 {
+			opts = append(opts, otlptracegrpc.WithHeaders(headers))
+		}
+		exp, err := otlptracegrpc.New(ctx, opts...)
+		if err != nil {
+			return nil, fmt.Errorf("creating otlpgrpc exporter: %w", err)
+		}
+		return exp, nil
+
+	case "otlphttp":
+		var opts []otlptracehttp.Option
+		if endpoint != "" {
+			opts = append(opts, otlptracehttp.WithEndpoint(endpoint))
+		}
+		if insecureConn {
+			opts = append(opts, otlptracehttp.WithInsecure())
+		}
+		if len(headers) > 0 {
+			opts = append(opts, otlptracehttp.WithHeaders(headers))
+		}
+		exp, err := otlptracehttp.New(ctx, opts...)
+		if err != nil {
+			return nil, fmt.Errorf("creating otlphttp exporter: %w", err)
+		}
+		return exp, nil
+
+	default:
+		return nil, fmt.Errorf(
+			"unknown otel-provider %q: must be one of: none, otlpgrpc, otlphttp",
+			providerName,
+		)
+	}
+}
+
+// ShutdownOTelProvider flushes all pending spans then shuts the provider down.
+// ForceFlush is always called before Shutdown, and Shutdown is attempted even
+// when the flush fails. Each phase runs under its own OTelShutdownTimeout
+// derived from ctx. Safe to call with nil (no-op).
+//
+// A flush failure is returned to the caller (combined with the Shutdown error
+// when both fail) rather than printed, so the caller decides how to log it.
+// When only Shutdown fails its error is returned unchanged.
+func ShutdownOTelProvider(ctx context.Context, provider otelShutdowner) error {
+	if provider == nil {
+		return nil
+	}
+
+	flushCtx, flushCancel := context.WithTimeout(ctx, OTelShutdownTimeout)
+	defer flushCancel()
+	flushErr := provider.ForceFlush(flushCtx)
+
+	// Shutdown must still be attempted after a failed flush.
+	shutCtx, shutCancel := context.WithTimeout(ctx, OTelShutdownTimeout)
+	defer shutCancel()
+	shutdownErr := provider.Shutdown(shutCtx)
+
+	switch {
+	case flushErr == nil:
+		return shutdownErr
+	case shutdownErr == nil:
+		return fmt.Errorf("flushing pending spans: %w", flushErr)
+	default:
+		return fmt.Errorf("flushing pending spans: %v; shutting down provider: %w", flushErr, shutdownErr)
+	}
+}
+
+// OTelPreRunE is a cobra.PreRunE function that initializes the OTel provider
+// and stores it on the command context so the shutdown handler can retrieve it.
+func OTelPreRunE(cmd *cobra.Command, _ []string) error {
+	provider, err := InitOTelProvider(cmd)
+	if err != nil {
+		return fmt.Errorf("initializing OTel provider: %w", err)
+	}
+
+	// cmd.Context() is nil when the command was never given a context.
+	parent := cmd.Context()
+	if parent == nil {
+		parent = context.Background()
+	}
+	cmd.SetContext(context.WithValue(parent, otelProviderContextKey{}, provider))
+	return nil
+}
+
+// OTelProviderFromContext retrieves the otelShutdowner stored by OTelPreRunE.
+// Returns nil if no provider was initialized (e.g. otel-provider was "none"),
+// if OTelPreRunE never ran for this context, or if ctx itself is nil.
+func OTelProviderFromContext(ctx context.Context) otelShutdowner {
+	// Callers pass cmd.Context(), which can be nil; calling Value on a nil
+	// context would panic.
+	if ctx == nil {
+		return nil
+	}
+	provider, _ := ctx.Value(otelProviderContextKey{}).(otelShutdowner)
+	return provider
+}
+
+// buildPropagator returns a composite TextMapPropagator for names, a
+// comma-separated list of propagation formats. Entries are trimmed and
+// matched case-insensitively: "b3" and "ottrace" select those formats, and
+// every other entry (including "w3c" and the empty string) selects W3C
+// Baggage plus TraceContext.
+func buildPropagator(names string) propagation.TextMapPropagator {
+	var selected []propagation.TextMapPropagator
+	for _, entry := range strings.Split(names, ",") {
+		switch strings.ToLower(strings.TrimSpace(entry)) {
+		case "b3":
+			selected = append(selected, b3.New())
+		case "ottrace":
+			selected = append(selected, ot.OT{})
+		default:
+			// "w3c" and any unrecognized entry share the W3C propagators.
+			selected = append(selected, propagation.Baggage{}, propagation.TraceContext{})
+		}
+	}
+	return propagation.NewCompositeTextMapPropagator(selected...)
+}
diff --git a/pkg/cmd/server/otel_integration_test.go b/pkg/cmd/server/otel_integration_test.go
new file mode 100644
index 000000000..7d0b35719
--- /dev/null
+++ b/pkg/cmd/server/otel_integration_test.go
@@ -0,0 +1,82 @@
+//go:build integration
+
+// pkg/cmd/server/otel_integration_test.go
+package server
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestOTelIntegration_FullChain_EnvToProvider sets the OTel flags for an
+// insecure otlpgrpc exporter, runs OTelPreRunE, and verifies that a non-nil
+// provider can then be retrieved from the command context.
+func TestOTelIntegration_FullChain_EnvToProvider(t *testing.T) {
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+	cmd.SetContext(context.Background())
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlpgrpc"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", "localhost:4317"))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+
+	require.NoError(t, OTelPreRunE(cmd, nil))
+
+	provider := OTelProviderFromContext(cmd.Context())
+	require.NotNil(t, provider, "TracerProvider must be non-nil after OTelPreRunE")
+
+	t.Cleanup(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		defer cancel()
+		_ = ShutdownOTelProvider(ctx, provider)
+	})
+}
+
+// TestOTelIntegration_ShutdownOnSignal verifies that ShutdownOTelProvider
+// completes without error when called as a signal handler would call it.
+func TestOTelIntegration_ShutdownOnSignal(t *testing.T) {
+	recorder := new(mockShutdowner)
+	shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	require.NoError(t, ShutdownOTelProvider(shutdownCtx, recorder))
+	assert.True(t, recorder.shutdownCalled, "expected Shutdown to be called")
+	assert.True(t, recorder.forceFlushCalled, "expected ForceFlush to be called")
+}
+
+// TestOTelIntegration_FlushBeforeShutdown checks the call ordering of
+// ShutdownOTelProvider: exactly two calls, ForceFlush first, Shutdown second.
+func TestOTelIntegration_FlushBeforeShutdown(t *testing.T) {
+	var calls []string
+	tracker := &callOrderShutdowner{callLog: &calls}
+
+	require.NoError(t, ShutdownOTelProvider(context.Background(), tracker))
+
+	require.Len(t, calls, 2)
+	assert.Equal(t, "ForceFlush", calls[0],
+		"ForceFlush must be called before Shutdown")
+	assert.Equal(t, "Shutdown", calls[1])
+}
+
+// TestOTelIntegration_NoneProvider_SafeShutdown runs OTelPreRunE with the
+// default provider ("none") and checks that the nil provider read back from
+// the context is accepted by ShutdownOTelProvider without error.
+func TestOTelIntegration_NoneProvider_SafeShutdown(t *testing.T) {
+	// No flag is set, so otel-provider keeps its "none" default.
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+	cmd.SetContext(context.Background())
+
+	require.NoError(t, OTelPreRunE(cmd, nil))
+
+	stored := OTelProviderFromContext(cmd.Context())
+	assert.Nil(t, stored)
+
+	assert.NoError(t, ShutdownOTelProvider(context.Background(), stored))
+}
diff --git a/pkg/cmd/server/otel_system_test.go b/pkg/cmd/server/otel_system_test.go
new file mode 100644
index 000000000..1af906217
--- /dev/null
+++ b/pkg/cmd/server/otel_system_test.go
@@ -0,0 +1,157 @@
+//go:build system
+
+// pkg/cmd/server/otel_system_test.go
+//
+// System tests use an in-process OTLP gRPC collector to verify end-to-end
+// span delivery. They require no external services.
+package server
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.opentelemetry.io/otel"
+	collectortrace "go.opentelemetry.io/proto/otlp/collector/trace/v1"
+	tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+)
+
+// ---------------------------------------------------------------------------
+// In-process OTLP collector
+// ---------------------------------------------------------------------------
+
+// inProcessCollector is a TraceService gRPC server that keeps every
+// ResourceSpans it is sent, so tests can assert on what was exported.
+// mu guards spans, which Export appends to and ReceivedSpanCount reads.
+type inProcessCollector struct {
+	collectortrace.UnimplementedTraceServiceServer
+	mu    sync.Mutex
+	spans []*tracev1.ResourceSpans
+}
+
+// Export records the ResourceSpans carried by req and always reports success.
+func (c *inProcessCollector) Export(
+	_ context.Context,
+	req *collectortrace.ExportTraceServiceRequest,
+) (*collectortrace.ExportTraceServiceResponse, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.spans = append(c.spans, req.ResourceSpans...)
+	return new(collectortrace.ExportTraceServiceResponse), nil
+}
+
+// ReceivedSpanCount returns how many individual spans have been recorded,
+// summed over every scope of every ResourceSpans received so far.
+func (c *inProcessCollector) ReceivedSpanCount() int {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	total := 0
+	for i := range c.spans {
+		for _, scoped := range c.spans[i].ScopeSpans {
+			total += len(scoped.Spans)
+		}
+	}
+	return total
+}
+
+// startCollector starts an in-process OTLP gRPC collector on a random port.
+// Returns the listener address, the collector, and a cleanup function.
+func startCollector(t *testing.T) (string, *inProcessCollector, func()) {
+	t.Helper()
+	lis, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err, "failed to start in-process collector listener")
+
+	collector := &inProcessCollector{}
+	srv := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
+	collectortrace.RegisterTraceServiceServer(srv, collector)
+
+	// Serve returns once the cleanup function stops the server; its result is
+	// not interesting to the tests, which assert on received spans instead.
+	go func() { _ = srv.Serve(lis) }()
+
+	return lis.Addr().String(), collector, srv.GracefulStop
+}
+
+// ---------------------------------------------------------------------------
+// System tests
+// ---------------------------------------------------------------------------
+
+// TestOTelSystem_SpansDeliveredToCollector starts an in-process OTLP
+// collector, initializes the OTel provider pointing at it, creates a test
+// span, force-flushes, and verifies the collector received at least one span.
+func TestOTelSystem_SpansDeliveredToCollector(t *testing.T) {
+	addr, collector, cleanup := startCollector(t)
+	// Registered before the provider cleanup below: t.Cleanup runs in
+	// last-in-first-out order, so the provider is shut down while the
+	// collector is still able to accept exports.
+	t.Cleanup(cleanup)
+
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+	cmd.SetContext(context.Background())
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlpgrpc"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", addr))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+	require.NoError(t, cmd.Flags().Set("otel-service-name", "spicedb-system-test"))
+	require.NoError(t, cmd.Flags().Set("otel-sample-ratio", "1.0"))
+
+	provider, err := InitOTelProvider(cmd)
+	require.NoError(t, err)
+	require.NotNil(t, provider)
+	t.Cleanup(func() {
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		_ = ShutdownOTelProvider(ctx, provider)
+	})
+
+	tracer := otel.Tracer("system-test")
+	_, span := tracer.Start(context.Background(), "test-span")
+	span.End()
+
+	flushCtx, flushCancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer flushCancel()
+	require.NoError(t, provider.ForceFlush(flushCtx))
+
+	require.Eventually(t,
+		func() bool { return collector.ReceivedSpanCount() >= 1 },
+		5*time.Second, 100*time.Millisecond,
+		"expected at least 1 span to be received by the in-process collector",
+	)
+}
+
+// TestOTelSystem_SpansNotDroppedOnShutdown verifies that spans buffered in
+// the BatchSpanProcessor are flushed before the provider shuts down, fixing
+// the data-loss scenario described in issue #3095.
+func TestOTelSystem_SpansNotDroppedOnShutdown(t *testing.T) {
+	const spanCount = 5
+
+	addr, collector, cleanup := startCollector(t)
+	// The provider is shut down explicitly inside the test body, so the
+	// collector cleanup always runs after it.
+	t.Cleanup(cleanup)
+
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+	cmd.SetContext(context.Background())
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlpgrpc"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", addr))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+	require.NoError(t, cmd.Flags().Set("otel-sample-ratio", "1.0"))
+
+	provider, err := InitOTelProvider(cmd)
+	require.NoError(t, err)
+	require.NotNil(t, provider)
+
+	tracer := otel.Tracer("shutdown-test")
+	for i := 0; i < spanCount; i++ {
+		_, span := tracer.Start(context.Background(), fmt.Sprintf("span-%d", i))
+		span.End()
+	}
+
+	// Shut down immediately — spans must be flushed before exit.
+	shutCtx, shutCancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer shutCancel()
+	require.NoError(t, ShutdownOTelProvider(shutCtx, provider),
+		"ShutdownOTelProvider must not error")
+
+	assert.GreaterOrEqual(t, collector.ReceivedSpanCount(), spanCount,
+		"all buffered spans must be delivered before Shutdown returns")
+}
diff --git a/pkg/cmd/server/otel_test.go b/pkg/cmd/server/otel_test.go
new file mode 100644
index 000000000..c1f942044
--- /dev/null
+++ b/pkg/cmd/server/otel_test.go
@@ -0,0 +1,236 @@
+// pkg/cmd/server/otel_test.go
+package server
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/spf13/cobra"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Helpers
+// ---------------------------------------------------------------------------
+
+// newTestCmd returns a no-op cobra command that has the OTel flags registered
+// and a background context attached.
+func newTestCmd(t *testing.T) *cobra.Command {
+	t.Helper()
+	cmd := &cobra.Command{
+		Use:  "test",
+		RunE: func(cmd *cobra.Command, args []string) error { return nil },
+	}
+	RegisterOTelFlags(cmd)
+	cmd.SetContext(context.Background())
+	return cmd
+}
+
+// mockShutdowner is a test double that records calls to Shutdown/ForceFlush
+// and returns the configured error from each.
+type mockShutdowner struct {
+	shutdownCalled   bool
+	forceFlushCalled bool
+	shutdownErr      error
+	forceFlushErr    error
+}
+
+func (m *mockShutdowner) Shutdown(_ context.Context) error {
+	m.shutdownCalled = true
+	return m.shutdownErr
+}
+
+func (m *mockShutdowner) ForceFlush(_ context.Context) error {
+	m.forceFlushCalled = true
+	return m.forceFlushErr
+}
+
+// callOrderShutdowner records the order Shutdown/ForceFlush are called.
+type callOrderShutdowner struct {
+	// callLog points at the caller's slice; each method appends its own name.
+	callLog *[]string
+}
+
+func (c *callOrderShutdowner) ForceFlush(_ context.Context) error {
+	*c.callLog = append(*c.callLog, "ForceFlush")
+	return nil
+}
+
+func (c *callOrderShutdowner) Shutdown(_ context.Context) error {
+	*c.callLog = append(*c.callLog, "Shutdown")
+	return nil
+}
+
+// ---------------------------------------------------------------------------
+// RegisterOTelFlags
+// ---------------------------------------------------------------------------
+
+// TestRegisterOTelFlags_AllFlagsPresent verifies every flag defined by
+// RegisterOTelFlags, including the hidden legacy ones, can be looked up by
+// name after registration.
+func TestRegisterOTelFlags_AllFlagsPresent(t *testing.T) {
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+
+	for _, name := range []string{
+		"otel-provider",
+		"otel-endpoint",
+		"otel-service-name",
+		"otel-trace-propagator",
+		"otel-insecure",
+		"otel-headers",
+		"otel-sample-ratio",
+		// Hidden legacy flags are still registered and must stay parseable.
+		"otel-jaeger-endpoint",
+		"otel-jaeger-service-name",
+	} {
+		assert.NotNil(t, cmd.Flags().Lookup(name),
+			"expected flag %q to be registered", name)
+	}
+}
+
+// TestRegisterOTelFlags_ProviderDefault verifies otel-provider defaults to "none".
+func TestRegisterOTelFlags_ProviderDefault(t *testing.T) {
+	cmd := &cobra.Command{Use: "test"}
+	RegisterOTelFlags(cmd)
+	val, err := cmd.Flags().GetString("otel-provider")
+	require.NoError(t, err)
+	assert.Equal(t, "none", val)
+}
+
+// ---------------------------------------------------------------------------
+// InitOTelProvider
+// ---------------------------------------------------------------------------
+
+// TestInitOTelProvider_NoneSkipsInit verifies provider=none returns (nil, nil)
+// without attempting any network connection.
+func TestInitOTelProvider_NoneSkipsInit(t *testing.T) {
+	cmd := newTestCmd(t)
+	require.NoError(t, cmd.Flags().Set("otel-provider", "none"))
+	provider, err := InitOTelProvider(cmd)
+	require.NoError(t, err)
+	assert.Nil(t, provider)
+}
+
+// TestInitOTelProvider_UnknownProviderReturnsError verifies an unrecognized
+// provider string returns a non-nil error containing the bad value.
+func TestInitOTelProvider_UnknownProviderReturnsError(t *testing.T) {
+	cmd := newTestCmd(t)
+	require.NoError(t, cmd.Flags().Set("otel-provider", "bogusprovider"))
+	_, err := InitOTelProvider(cmd)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "bogusprovider")
+}
+
+// TestInitOTelProvider_OtlpGrpc_ValidEndpoint verifies otlpgrpc initializes
+// without error. No live collector required — connection errors surface only
+// on first export, not at initialization.
+func TestInitOTelProvider_OtlpGrpc_ValidEndpoint(t *testing.T) {
+	cmd := newTestCmd(t)
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlpgrpc"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", "localhost:4317"))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+
+	provider, err := InitOTelProvider(cmd)
+	require.NoError(t, err)
+	// require (not assert): the cleanup below calls a method on provider and
+	// would panic on nil instead of reporting a clean failure.
+	require.NotNil(t, provider)
+	t.Cleanup(func() { _ = provider.Shutdown(context.Background()) })
+}
+
+// TestInitOTelProvider_OtlpHttp_ValidEndpoint verifies otlphttp initializes
+// without error. No live collector required.
+func TestInitOTelProvider_OtlpHttp_ValidEndpoint(t *testing.T) {
+	cmd := newTestCmd(t)
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlphttp"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", "localhost:4318"))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+
+	provider, err := InitOTelProvider(cmd)
+	require.NoError(t, err)
+	// require (not assert): the cleanup below calls a method on provider and
+	// would panic on nil instead of reporting a clean failure.
+	require.NotNil(t, provider)
+	t.Cleanup(func() { _ = provider.Shutdown(context.Background()) })
+}
+
+// ---------------------------------------------------------------------------
+// ShutdownOTelProvider
+// ---------------------------------------------------------------------------
+
+// TestShutdownOTelProvider_NilProvider_NoError verifies nil provider is safe.
+func TestShutdownOTelProvider_NilProvider_NoError(t *testing.T) {
+	err := ShutdownOTelProvider(context.Background(), nil)
+	assert.NoError(t, err)
+}
+
+// TestShutdownOTelProvider_CallsFlushThenShutdown verifies ForceFlush is
+// called before Shutdown, and both are called exactly once.
+func TestShutdownOTelProvider_CallsFlushThenShutdown(t *testing.T) {
+	callOrder := []string{}
+	provider := &callOrderShutdowner{callLog: &callOrder}
+
+	err := ShutdownOTelProvider(context.Background(), provider)
+	require.NoError(t, err)
+	require.Len(t, callOrder, 2)
+	assert.Equal(t, "ForceFlush", callOrder[0], "ForceFlush must be called before Shutdown")
+	assert.Equal(t, "Shutdown", callOrder[1])
+}
+
+// TestShutdownOTelProvider_ShutdownErrorPropagated verifies that an error
+// from Shutdown is returned to the caller, and that the flush still ran.
+func TestShutdownOTelProvider_ShutdownErrorPropagated(t *testing.T) {
+	mock := &mockShutdowner{shutdownErr: fmt.Errorf("shutdown failed")}
+	err := ShutdownOTelProvider(context.Background(), mock)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "shutdown failed")
+	assert.True(t, mock.forceFlushCalled, "ForceFlush must run before Shutdown")
+}
+
+// TestShutdownOTelProvider_ForceFlushErrorContinuesToShutdown verifies that
+// a ForceFlush error does not prevent Shutdown from being called.
+func TestShutdownOTelProvider_ForceFlushErrorContinuesToShutdown(t *testing.T) {
+	mock := &mockShutdowner{forceFlushErr: fmt.Errorf("flush failed")}
+	// The returned error is deliberately not inspected: this test only pins
+	// that both phases run when the flush fails.
+	_ = ShutdownOTelProvider(context.Background(), mock)
+	assert.True(t, mock.forceFlushCalled, "ForceFlush must be attempted first")
+	assert.True(t, mock.shutdownCalled,
+		"Shutdown must be called even when ForceFlush errors")
+}
+
+// TestShutdownOTelProvider_ContextCancelled verifies that an already
+// cancelled context still drives both phases and that the provider's
+// Shutdown error is surfaced to the caller.
+func TestShutdownOTelProvider_ContextCancelled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	mock := &mockShutdowner{shutdownErr: context.Canceled}
+
+	err := ShutdownOTelProvider(ctx, mock)
+
+	require.ErrorIs(t, err, context.Canceled)
+	assert.True(t, mock.forceFlushCalled, "ForceFlush must still be attempted")
+	assert.True(t, mock.shutdownCalled, "Shutdown must still be attempted")
+}
+
+// ---------------------------------------------------------------------------
+// OTelPreRunE and OTelProviderFromContext
+// ---------------------------------------------------------------------------
+
+// TestOTelPreRunE_StoresProviderInContext verifies that after OTelPreRunE
+// runs with a real provider, OTelProviderFromContext returns non-nil.
+func TestOTelPreRunE_StoresProviderInContext(t *testing.T) {
+	cmd := newTestCmd(t)
+	require.NoError(t, cmd.Flags().Set("otel-provider", "otlpgrpc"))
+	require.NoError(t, cmd.Flags().Set("otel-endpoint", "localhost:4317"))
+	require.NoError(t, cmd.Flags().Set("otel-insecure", "true"))
+
+	err := OTelPreRunE(cmd, nil)
+	require.NoError(t, err)
+
+	provider := OTelProviderFromContext(cmd.Context())
+	// require (not assert): the cleanup below calls a method on provider and
+	// would panic on nil instead of reporting a clean failure.
+	require.NotNil(t, provider, "TracerProvider must be stored in context")
+	t.Cleanup(func() { _ = provider.Shutdown(context.Background()) })
+}
+
+// TestOTelPreRunE_NoneIsNoop verifies provider=none results in nil context
+// value and no error.
+func TestOTelPreRunE_NoneIsNoop(t *testing.T) {
+	cmd := newTestCmd(t)
+	// otel-provider defaults to "none" — no Set needed
+
+	err := OTelPreRunE(cmd, nil)
+	require.NoError(t, err)
+
+	provider := OTelProviderFromContext(cmd.Context())
+	assert.Nil(t, provider)
+}
+
+// TestOTelProviderFromContext_MissingKey verifies that retrieving a provider
+// from a context where OTelPreRunE was never called returns nil without panic.
+func TestOTelProviderFromContext_MissingKey(t *testing.T) {
+	provider := OTelProviderFromContext(context.Background())
+	assert.Nil(t, provider)
+}
diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index f7c605846..7505b09fa 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -13,7 +13,6 @@ import (
 	"net/http"
 	"time"
 
-	"github.com/jzelinskie/cobrautil/v2/cobraotel"
 	_ "github.com/mostynb/go-grpc-compression/experimental/s2" // Register Snappy S2 compression
 	"github.com/rs/zerolog"
 	"github.com/spf13/cobra"
@@ -439,8 +438,25 @@ func (d *disabledHTTPServer) Close() {}
 // so that they were shared across all commands, but this
 // made it difficult to organize the flags, so we lifted them here.
 func RegisterCommonFlags(cmd *cobra.Command) {
-	otel := cobraotel.New("spicedb")
-	otel.RegisterFlags(cmd.Flags())
+	// The otel-* flags below use the same names and defaults as
+	// server.RegisterOTelFlags; keep the two definitions in sync.
+	const endpointUsage = `OpenTelemetry collector endpoint - the endpoint can also be set by using environment variables`
+
+	f := cmd.Flags()
+	f.String("otel-provider", "none", `OpenTelemetry provider for tracing ("none", "otlpgrpc", "otlphttp")`)
+	f.String("otel-endpoint", "", endpointUsage)
+	f.String("otel-service-name", "spicedb", `service name for trace data`)
+	f.String("otel-trace-propagator", "w3c", `OpenTelemetry trace propagation format ("b3", "w3c", "ottrace"). Add multiple propagators separated by comma.`)
+	f.Bool("otel-insecure", false, `connect to the OpenTelemetry collector in plaintext`)
+	f.StringToString("otel-headers", nil, `key=value pairs sent as headers to the OTel provider`)
+	f.Float64("otel-sample-ratio", 0.01, `ratio of traces that are sampled`)
+
+	// Legacy flags: still parsed, but hidden from help. MarkHidden only errors
+	// for an unknown flag name, which cannot happen right after registration.
+	f.String("otel-jaeger-endpoint", "", endpointUsage)
+	_ = f.MarkHidden("otel-jaeger-endpoint")
+	f.String("otel-jaeger-service-name", "spicedb", "service name for trace data")
+	_ = f.MarkHidden("otel-jaeger-service-name")
 	termination.RegisterFlags(cmd.Flags())
 	runtime.RegisterFlags(cmd.Flags())
 }