diff --git a/telemetry/DESIGN.md b/telemetry/DESIGN.md index c239ea0..a9107c7 100644 --- a/telemetry/DESIGN.md +++ b/telemetry/DESIGN.md @@ -1422,7 +1422,7 @@ func checkFeatureFlag(ctx context.Context, host string, httpClient *http.Client) // Add query parameters q := req.URL.Query() - q.Add("flags", "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc") + q.Add("flags", "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver") req.URL.RawQuery = q.Encode() resp, err := httpClient.Do(req) @@ -1442,7 +1442,7 @@ func checkFeatureFlag(ctx context.Context, host string, httpClient *http.Client) return false, err } - return result.Flags["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"], nil + return result.Flags["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver"], nil } ``` @@ -1743,7 +1743,7 @@ func BenchmarkInterceptor_Disabled(b *testing.B) { - [x] Add unit tests for configuration and tags ### Phase 2: Per-Host Management -- [ ] Implement `featureflag.go` with caching and reference counting +- [x] Implement `featureflag.go` with caching and reference counting (PECOBLR-1146) - [ ] Implement `manager.go` for client management - [ ] Implement `circuitbreaker.go` with state machine - [ ] Add unit tests for all components diff --git a/telemetry/featureflag.go b/telemetry/featureflag.go new file mode 100644 index 0000000..6f19781 --- /dev/null +++ b/telemetry/featureflag.go @@ -0,0 +1,156 @@ +package telemetry + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" +) + +// featureFlagCache manages feature flag state per host with reference counting. +// This prevents rate limiting by caching feature flag responses. +type featureFlagCache struct { + mu sync.RWMutex + contexts map[string]*featureFlagContext +} + +// featureFlagContext holds feature flag state and reference count for a host. +type featureFlagContext struct { + enabled *bool + lastFetched time.Time + refCount int + cacheDuration time.Duration +} + +var ( + flagCacheOnce sync.Once + flagCacheInstance *featureFlagCache +) + +// getFeatureFlagCache returns the singleton instance. +func getFeatureFlagCache() *featureFlagCache { + flagCacheOnce.Do(func() { + flagCacheInstance = &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + }) + return flagCacheInstance +} + +// getOrCreateContext gets or creates a feature flag context for the host. +// Increments reference count. +func (c *featureFlagCache) getOrCreateContext(host string) *featureFlagContext { + c.mu.Lock() + defer c.mu.Unlock() + + ctx, exists := c.contexts[host] + if !exists { + ctx = &featureFlagContext{ + cacheDuration: 15 * time.Minute, + } + c.contexts[host] = ctx + } + ctx.refCount++ + return ctx +} + +// releaseContext decrements reference count for the host. +// Removes context when ref count reaches zero. +func (c *featureFlagCache) releaseContext(host string) { + c.mu.Lock() + defer c.mu.Unlock() + + if ctx, exists := c.contexts[host]; exists { + ctx.refCount-- + if ctx.refCount <= 0 { + delete(c.contexts, host) + } + } +} + +// isTelemetryEnabled checks if telemetry is enabled for the host. +// Uses cached value if available and not expired. +func (c *featureFlagCache) isTelemetryEnabled(ctx context.Context, host string, httpClient *http.Client) (bool, error) { + c.mu.RLock() + flagCtx, exists := c.contexts[host] + c.mu.RUnlock() + + if !exists { + return false, nil + } + + // Check if cache is valid + if flagCtx.enabled != nil && time.Since(flagCtx.lastFetched) < flagCtx.cacheDuration { + return *flagCtx.enabled, nil + } + + // Fetch fresh value + enabled, err := fetchFeatureFlag(ctx, host, httpClient) + if err != nil { + // Return cached value on error, or false if no cache + if flagCtx.enabled != nil { + return *flagCtx.enabled, nil + } + return false, err + } + + // Update cache + c.mu.Lock() + flagCtx.enabled = &enabled + flagCtx.lastFetched = time.Now() + c.mu.Unlock() + + return enabled, nil +} + +// isExpired returns true if the cache has expired. +func (c *featureFlagContext) isExpired() bool { + return c.enabled == nil || time.Since(c.lastFetched) > c.cacheDuration +} + +// fetchFeatureFlag fetches the feature flag value from Databricks. +func fetchFeatureFlag(ctx context.Context, host string, httpClient *http.Client) (bool, error) { + // Construct endpoint URL, adding https:// if not already present + var endpoint string + if len(host) > 7 && (host[:7] == "http://" || host[:8] == "https://") { + endpoint = fmt.Sprintf("%s/api/2.0/feature-flags", host) + } else { + endpoint = fmt.Sprintf("https://%s/api/2.0/feature-flags", host) + } + + req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) + if err != nil { + return false, fmt.Errorf("failed to create feature flag request: %w", err) + } + + // Add query parameter for the specific feature flag + q := req.URL.Query() + q.Add("flags", "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver") + req.URL.RawQuery = q.Encode() + + resp, err := httpClient.Do(req) + if err != nil { + return false, fmt.Errorf("failed to fetch feature flag: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return false, fmt.Errorf("feature flag check failed: %d", resp.StatusCode) + } + + var result struct { + Flags map[string]bool `json:"flags"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return false, fmt.Errorf("failed to decode feature flag response: %w", err) + } + + enabled, ok := result.Flags["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver"] + if !ok { + return false, nil + } + + return enabled, nil +} diff --git a/telemetry/featureflag_test.go b/telemetry/featureflag_test.go new file mode 100644 index 0000000..b45fc8f --- /dev/null +++ b/telemetry/featureflag_test.go @@ -0,0 +1,445 @@ +package telemetry + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" +) + +func TestGetFeatureFlagCache_Singleton(t *testing.T) { + // Reset singleton for testing + flagCacheInstance = nil + flagCacheOnce = sync.Once{} + + cache1 := getFeatureFlagCache() + cache2 := getFeatureFlagCache() + + if cache1 != cache2 { + t.Error("Expected singleton instances to be the same") + } +} + +func TestFeatureFlagCache_GetOrCreateContext(t *testing.T) { + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := "test-host.databricks.com" + + // First call should create context and increment refCount to 1 + ctx1 := cache.getOrCreateContext(host) + if ctx1 == nil { + t.Fatal("Expected context to be created") + } + if ctx1.refCount != 1 { + t.Errorf("Expected refCount to be 1, got %d", ctx1.refCount) + } + + // Second call should reuse context and increment refCount to 2 + ctx2 := cache.getOrCreateContext(host) + if ctx2 != ctx1 { + t.Error("Expected to get the same context instance") + } + if ctx2.refCount != 2 { + t.Errorf("Expected refCount to be 2, got %d", ctx2.refCount) + } + + // Verify cache duration is set + if ctx1.cacheDuration != 15*time.Minute { + t.Errorf("Expected cache duration to be 15 minutes, got %v", ctx1.cacheDuration) + } +} + +func TestFeatureFlagCache_ReleaseContext(t *testing.T) { + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := "test-host.databricks.com" + + // Create context with refCount = 2 + cache.getOrCreateContext(host) + cache.getOrCreateContext(host) + + // First release should decrement to 1 + cache.releaseContext(host) + ctx, exists := cache.contexts[host] + if !exists { + t.Fatal("Expected context to still exist") + } + if ctx.refCount != 1 { + t.Errorf("Expected refCount to be 1, got %d", ctx.refCount) + } + + // Second release should remove context + cache.releaseContext(host) + _, exists = cache.contexts[host] + if exists { + t.Error("Expected context to be removed when refCount reaches 0") + } + + // Release non-existent context should not panic + cache.releaseContext("non-existent-host") +} + +func TestFeatureFlagCache_IsTelemetryEnabled_Cached(t *testing.T) { + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := "test-host.databricks.com" + ctx := cache.getOrCreateContext(host) + + // Set cached value + enabled := true + ctx.enabled = &enabled + ctx.lastFetched = time.Now() + + // Should return cached value without HTTP call + result, err := cache.isTelemetryEnabled(context.Background(), host, nil) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result != true { + t.Error("Expected cached value to be returned") + } +} + +func TestFeatureFlagCache_IsTelemetryEnabled_Expired(t *testing.T) { + // Create mock server + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + })) + defer server.Close() + + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := server.URL // Use full URL for testing + ctx := cache.getOrCreateContext(host) + + // Set expired cached value + enabled := false + ctx.enabled = &enabled + ctx.lastFetched = time.Now().Add(-20 * time.Minute) // Expired + + // Should fetch fresh value + httpClient := &http.Client{} + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result != true { + t.Error("Expected fresh value to be fetched and returned") + } + if callCount != 1 { + t.Errorf("Expected HTTP call to be made once, got %d calls", callCount) + } + + // Verify cache was updated + if *ctx.enabled != true { + t.Error("Expected cache to be updated with new value") + } +} + +func TestFeatureFlagCache_IsTelemetryEnabled_NoContext(t *testing.T) { + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := "non-existent-host.databricks.com" + + // Should return false for non-existent context + result, err := cache.isTelemetryEnabled(context.Background(), host, nil) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result != false { + t.Error("Expected false for non-existent context") + } +} + +func TestFeatureFlagCache_IsTelemetryEnabled_ErrorFallback(t *testing.T) { + // Create mock server that returns error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := server.URL // Use full URL for testing + ctx := cache.getOrCreateContext(host) + + // Set cached value + enabled := true + ctx.enabled = &enabled + ctx.lastFetched = time.Now().Add(-20 * time.Minute) // Expired + + // Should return cached value on error + httpClient := &http.Client{} + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + if err != nil { + t.Errorf("Expected no error (fallback to cache), got %v", err) + } + if result != true { + t.Error("Expected cached value to be returned on fetch error") + } +} + +func TestFeatureFlagCache_IsTelemetryEnabled_ErrorNoCache(t *testing.T) { + // Create mock server that returns error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := server.URL // Use full URL for testing + cache.getOrCreateContext(host) + + // No cached value, should return error + httpClient := &http.Client{} + result, err := cache.isTelemetryEnabled(context.Background(), host, httpClient) + if err == nil { + t.Error("Expected error when no cache available and fetch fails") + } + if result != false { + t.Error("Expected false when no cache available and fetch fails") + } +} + +func TestFeatureFlagCache_ConcurrentAccess(t *testing.T) { + cache := &featureFlagCache{ + contexts: make(map[string]*featureFlagContext), + } + + host := "test-host.databricks.com" + numGoroutines := 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Concurrent getOrCreateContext + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + cache.getOrCreateContext(host) + }() + } + wg.Wait() + + // Verify refCount + ctx, exists := cache.contexts[host] + if !exists { + t.Fatal("Expected context to exist") + } + if ctx.refCount != numGoroutines { + t.Errorf("Expected refCount to be %d, got %d", numGoroutines, ctx.refCount) + } + + // Concurrent releaseContext + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + cache.releaseContext(host) + }() + } + wg.Wait() + + // Verify context is removed + _, exists = cache.contexts[host] + if exists { + t.Error("Expected context to be removed after all releases") + } +} + +func TestFeatureFlagContext_IsExpired(t *testing.T) { + tests := []struct { + name string + enabled *bool + fetched time.Time + duration time.Duration + want bool + }{ + { + name: "no cache", + enabled: nil, + fetched: time.Time{}, + duration: 15 * time.Minute, + want: true, + }, + { + name: "fresh cache", + enabled: boolPtr(true), + fetched: time.Now(), + duration: 15 * time.Minute, + want: false, + }, + { + name: "expired cache", + enabled: boolPtr(true), + fetched: time.Now().Add(-20 * time.Minute), + duration: 15 * time.Minute, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &featureFlagContext{ + enabled: tt.enabled, + lastFetched: tt.fetched, + cacheDuration: tt.duration, + } + if got := ctx.isExpired(); got != tt.want { + t.Errorf("isExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFetchFeatureFlag_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request + if r.Method != "GET" { + t.Errorf("Expected GET request, got %s", r.Method) + } + if r.URL.Path != "/api/2.0/feature-flags" { + t.Errorf("Expected /api/2.0/feature-flags path, got %s", r.URL.Path) + } + + flags := r.URL.Query().Get("flags") + expectedFlag := "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver" + if flags != expectedFlag { + t.Errorf("Expected flag query param %s, got %s", expectedFlag, flags) + } + + // Return success response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": true}}`)) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + enabled, err := fetchFeatureFlag(context.Background(), host, httpClient) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if !enabled { + t.Error("Expected feature flag to be enabled") + } +} + +func TestFetchFeatureFlag_Disabled(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"flags": {"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForGoDriver": false}}`)) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + enabled, err := fetchFeatureFlag(context.Background(), host, httpClient) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if enabled { + t.Error("Expected feature flag to be disabled") + } +} + +func TestFetchFeatureFlag_FlagNotPresent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"flags": {}}`)) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + enabled, err := fetchFeatureFlag(context.Background(), host, httpClient) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if enabled { + t.Error("Expected feature flag to be false when not present") + } +} + +func TestFetchFeatureFlag_HTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + _, err := fetchFeatureFlag(context.Background(), host, httpClient) + if err == nil { + t.Error("Expected error for HTTP 500") + } +} + +func TestFetchFeatureFlag_InvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`invalid json`)) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + _, err := fetchFeatureFlag(context.Background(), host, httpClient) + if err == nil { + t.Error("Expected error for invalid JSON") + } +} + +func TestFetchFeatureFlag_ContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + host := server.URL // Use full URL for testing + httpClient := &http.Client{} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + _, err := fetchFeatureFlag(ctx, host, httpClient) + if err == nil { + t.Error("Expected error for cancelled context") + } +} + +// Helper function to create bool pointer +func boolPtr(b bool) *bool { + return &b +}