diff --git a/csharp/Directory.Packages.props b/csharp/Directory.Packages.props index 58713e76..c98b74e0 100644 --- a/csharp/Directory.Packages.props +++ b/csharp/Directory.Packages.props @@ -35,6 +35,8 @@ + + diff --git a/csharp/doc/telemetry-design.md b/csharp/doc/telemetry-design.md index 9e0717b7..f012b2dd 100644 --- a/csharp/doc/telemetry-design.md +++ b/csharp/doc/telemetry-design.md @@ -176,69 +176,419 @@ sequenceDiagram ### 3.1 FeatureFlagCache (Per-Host) -**Purpose**: Cache feature flag values at the host level to avoid repeated API calls and rate limiting. +**Purpose**: Cache **all** feature flag values at the host level to avoid repeated API calls and rate limiting. This is a generic cache that can be used for any driver configuration controlled by server-side feature flags, not just telemetry. -**Location**: `Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.FeatureFlagCache` +**Location**: `AdbcDrivers.Databricks.FeatureFlagCache` (note: not in Telemetry namespace - this is a general-purpose component) #### Rationale +- **Generic feature flag support**: Cache returns all flags, allowing any driver feature to be controlled server-side - **Per-host caching**: Feature flags cached by host (not per connection) to prevent rate limiting - **Reference counting**: Tracks number of connections per host for proper cleanup -- **Automatic expiration**: Refreshes cached flags after TTL expires (15 minutes) +- **Server-controlled TTL**: Refresh interval controlled by server-provided `ttl_seconds` (default: 15 minutes) +- **Background refresh**: Scheduled refresh at server-specified intervals - **Thread-safe**: Uses ConcurrentDictionary for concurrent access from multiple connections +#### Configuration Priority Order + +Feature flags are integrated directly into the existing ADBC driver property parsing logic as an **extra layer** in the property value resolution. The priority order is: + +``` +1. User-specified properties (highest priority) +2. Feature flags from server +3. Driver default values (lowest priority) +``` + +**Integration Approach**: Feature flags are merged into the `Properties` dictionary at connection initialization time. This means: +- The existing `Properties.TryGetValue()` pattern continues to work unchanged +- Feature flags are transparently available as properties +- No changes needed to existing property parsing code + +```mermaid +flowchart LR + A[User Properties] --> D[Merged Properties] + B[Feature Flags] --> D + C[Driver Defaults] --> D + D --> E[Properties.TryGetValue] + E --> F[Existing Parsing Logic] +``` + +**Merge Logic** (in `DatabricksConnection` initialization): +```csharp +// Current flow: +// 1. User properties from connection string/config +// 2. Environment properties from DATABRICKS_CONFIG_FILE + +// New flow with feature flags: +// 1. User properties from connection string/config (highest priority) +// 2. Feature flags from server (middle priority) +// 3. 
Environment properties / driver defaults (lowest priority) + +private Dictionary MergePropertiesWithFeatureFlags( + Dictionary userProperties, + IReadOnlyDictionary featureFlags) +{ + var merged = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Start with feature flags as base (lower priority) + foreach (var flag in featureFlags) + { + // Map feature flag names to property names if needed + string propertyName = MapFeatureFlagToPropertyName(flag.Key); + if (propertyName != null) + { + merged[propertyName] = flag.Value; + } + } + + // Override with user properties (higher priority) + foreach (var prop in userProperties) + { + merged[prop.Key] = prop.Value; + } + + return merged; +} +``` + +**Feature Flag to Property Name Mapping**: +```csharp +// Feature flags have long names, map to driver property names +private static readonly Dictionary FeatureFlagToPropertyMap = new() +{ + ["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"] = "telemetry.enabled", + ["databricks.partnerplatform.clientConfigsFeatureFlags.enableCloudFetch"] = "cloudfetch.enabled", + // ... more mappings +}; +``` + +This approach: +- **Preserves existing code**: All `Properties.TryGetValue()` calls work unchanged +- **Transparent integration**: Feature flags appear as regular properties after merge +- **Clear priority**: User settings always win over server flags +- **Single merge point**: Feature flag integration happens once at connection initialization +- **Fresh values per connection**: Each new connection uses the latest cached feature flag values + +#### Per-Connection Property Resolution + +Each new connection applies property merging with the **latest** cached feature flag values: + +```mermaid +sequenceDiagram + participant C1 as Connection 1 + participant C2 as Connection 2 + participant FFC as FeatureFlagCache + participant BG as Background Refresh + + Note over FFC: Cache has flags v1 + + C1->>FFC: GetOrCreateContext(host) + FFC-->>C1: context (refCount=1) + C1->>C1: Merge properties with flags v1 + Note over C1: Properties frozen with v1 + + BG->>FFC: Refresh flags + Note over FFC: Cache updated to flags v2 + + C2->>FFC: GetOrCreateContext(host) + FFC-->>C2: context (refCount=2) + C2->>C2: Merge properties with flags v2 + Note over C2: Properties frozen with v2 + + Note over C1,C2: C1 uses v1, C2 uses v2 +``` + +**Key Points**: +- **Shared cache, per-connection merge**: The `FeatureFlagCache` is shared (per-host), but property merging happens at each connection initialization +- **Latest values for new connections**: When a new connection is created, it reads the current cached values (which may have been updated by background refresh) +- **Stable within connection**: Once merged, a connection's `Properties` are stable for its lifetime (no mid-connection changes) +- **Background refresh benefits new connections**: The scheduled refresh ensures new connections get up-to-date flag values without waiting for a fetch + +#### Feature Flag API + +**Endpoint**: `GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{driver_version}` + +> **Note**: Currently using the JDBC endpoint (`OSS_JDBC`) until the ADBC endpoint (`OSS_ADBC`) is configured server-side. The feature flag name will still use `enableTelemetryForAdbc` to distinguish ADBC telemetry from JDBC telemetry. + +Where `{driver_version}` is the driver version (e.g., `1.0.0`). 
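+
+As a hedged illustration (variable names are not from the implementation), the driver builds this path by substituting the version into the `FeatureFlagEndpointFormat` constant defined later in this document:
+
+```csharp
+// Sketch: format the feature flag endpoint from the driver version.
+const string endpointFormat = "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}";
+string driverVersion = "1.0.0"; // illustrative value
+string endpoint = string.Format(endpointFormat, driverVersion);
+// => "/api/2.0/connector-service/feature-flags/OSS_JDBC/1.0.0"
+```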
+ +**Request Headers**: +- `Authorization`: Bearer token (same as connection auth) +- `User-Agent`: Custom user agent for connector service + +**Response Format** (JSON): +```json +{ + "flags": [ + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc", + "value": "true" + }, + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableCloudFetch", + "value": "true" + }, + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.maxDownloadThreads", + "value": "10" + } + ], + "ttl_seconds": 900 +} +``` + +**Response Fields**: +- `flags`: Array of feature flag entries with `name` and `value` (string). Names can be mapped to driver property names. +- `ttl_seconds`: Server-controlled refresh interval in seconds (default: 900 = 15 minutes) + +**JDBC Reference**: See `DatabricksDriverFeatureFlagsContext.java:30-33` for endpoint format. + +#### Refresh Strategy + +The feature flag cache follows the JDBC driver pattern: + +1. **Initial Blocking Fetch**: On connection open, make a blocking HTTP call to fetch all feature flags +2. **Cache All Flags**: Store all returned flags in a local cache (Guava Cache in JDBC, ConcurrentDictionary in C#) +3. **Scheduled Background Refresh**: Start a daemon thread that refreshes flags at intervals based on `ttl_seconds` +4. **Dynamic TTL**: If server returns a different `ttl_seconds`, reschedule the refresh interval + +```mermaid +sequenceDiagram + participant Conn as Connection.Open + participant FFC as FeatureFlagContext + participant Server as Databricks Server + + Conn->>FFC: GetOrCreateContext(host) + FFC->>Server: GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{version} + Note over FFC,Server: Blocking initial fetch + Server-->>FFC: {flags: [...], ttl_seconds: 900} + FFC->>FFC: Cache all flags + FFC->>FFC: Schedule refresh at ttl_seconds interval + + loop Every ttl_seconds + FFC->>Server: GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{version} + Server-->>FFC: {flags: [...], ttl_seconds: N} + FFC->>FFC: Update cache, reschedule if TTL changed + end +``` + +**JDBC Reference**: See `DatabricksDriverFeatureFlagsContext.java:48-58` for initial fetch and scheduling. + +#### HTTP Client Pattern + +The feature flag cache does **not** use a separate dedicated HTTP client. Instead, it reuses the connection's existing HTTP client infrastructure: + +```mermaid +graph LR + A[DatabricksConnection] -->|provides| B[HttpClient] + A -->|provides| C[Auth Headers] + B --> D[FeatureFlagContext] + C --> D + D -->|HTTP GET| E[Feature Flag Endpoint] +``` + +**Key Points**: +1. **Reuse connection's HttpClient**: The `FeatureFlagContext` receives the connection's `HttpClient` (already configured with base address, timeouts, etc.) +2. **Reuse connection's authentication**: Auth headers (Bearer token) come from the connection's authentication mechanism +3. 
**Custom User-Agent**: Set a connector-service-specific User-Agent header for the feature flag requests + +**JDBC Implementation** (`DatabricksDriverFeatureFlagsContext.java:89-105`): +```java +// Get shared HTTP client from connection +IDatabricksHttpClient httpClient = + DatabricksHttpClientFactory.getInstance().getClient(connectionContext); + +// Create request +HttpGet request = new HttpGet(featureFlagEndpoint); + +// Set custom User-Agent for connector service +request.setHeader("User-Agent", + UserAgentManager.buildUserAgentForConnectorService(connectionContext)); + +// Add auth headers from connection's auth config +DatabricksClientConfiguratorManager.getInstance() + .getConfigurator(connectionContext) + .getDatabricksConfig() + .authenticate() + .forEach(request::addHeader); +``` + +**C# Equivalent Pattern**: +```csharp +// In DatabricksConnection - create HttpClient for feature flags +private HttpClient CreateFeatureFlagHttpClient() +{ + var handler = HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions, _proxyConfigurator); + var httpClient = new HttpClient(handler); + + // Set base address + httpClient.BaseAddress = new Uri($"https://{_host}"); + + // Set auth header (reuse connection's token) + if (Properties.TryGetValue(SparkParameters.Token, out string? token)) + { + httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", token); + } + + // Set custom User-Agent for connector service + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd( + BuildConnectorServiceUserAgent()); + + return httpClient; +} + +// Pass to FeatureFlagContext +var context = featureFlagCache.GetOrCreateContext(_host, CreateFeatureFlagHttpClient()); +``` + +This approach: +- Avoids duplicating HTTP client configuration +- Ensures consistent authentication across all API calls +- Allows proper resource cleanup when connection closes + #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks { /// /// Singleton that manages feature flag cache per host. /// Prevents rate limiting by caching feature flag responses. + /// This is a generic cache for all feature flags, not just telemetry. /// internal sealed class FeatureFlagCache { - private static readonly FeatureFlagCache Instance = new(); - public static FeatureFlagCache GetInstance() => Instance; + private static readonly FeatureFlagCache s_instance = new FeatureFlagCache(); + public static FeatureFlagCache GetInstance() => s_instance; /// /// Gets or creates a feature flag context for the host. /// Increments reference count. + /// Makes initial blocking fetch if context is new. /// - public FeatureFlagContext GetOrCreateContext(string host); + public FeatureFlagContext GetOrCreateContext(string host, HttpClient httpClient, string driverVersion); /// /// Decrements reference count for the host. - /// Removes context when ref count reaches zero. + /// Removes context and stops refresh scheduler when ref count reaches zero. /// public void ReleaseContext(string host); + } + /// + /// Holds feature flag state and reference count for a host. + /// Manages background refresh scheduling. + /// Uses the HttpClient provided by the connection for API calls. + /// + internal sealed class FeatureFlagContext : IDisposable + { /// - /// Checks if telemetry is enabled for the host. - /// Uses cached value if available and not expired. + /// Creates a new context with the given HTTP client. + /// Makes initial blocking fetch to populate cache. 
+ /// Starts background refresh scheduler. /// - public Task IsTelemetryEnabledAsync( - string host, - HttpClient httpClient, - CancellationToken ct = default); + /// The Databricks host. + /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + public FeatureFlagContext(string host, HttpClient httpClient); + + public int RefCount { get; } + public TimeSpan RefreshInterval { get; } // From server ttl_seconds + + /// + /// Gets a feature flag value by name. + /// Returns null if the flag is not found. + /// + public string? GetFlagValue(string flagName); + + /// + /// Checks if a feature flag is enabled (value is "true"). + /// Returns false if flag is not found or value is not "true". + /// + public bool IsFeatureEnabled(string flagName); + + /// + /// Gets all cached feature flags as a dictionary. + /// Can be used to merge with user properties. + /// + public IReadOnlyDictionary GetAllFlags(); + + /// + /// Stops the background refresh scheduler. + /// + public void Shutdown(); + + public void Dispose(); } /// - /// Holds feature flag state and reference count for a host. + /// Response model for feature flags API. /// - internal sealed class FeatureFlagContext + internal sealed class FeatureFlagsResponse { - public bool? TelemetryEnabled { get; set; } - public DateTime? LastFetched { get; set; } - public int RefCount { get; set; } - public TimeSpan CacheDuration { get; } = TimeSpan.FromMinutes(15); + public List? Flags { get; set; } + public int? TtlSeconds { get; set; } + } - public bool IsExpired => LastFetched == null || - DateTime.UtcNow - LastFetched.Value > CacheDuration; + internal sealed class FeatureFlagEntry + { + public string Name { get; set; } = string.Empty; + public string Value { get; set; } = string.Empty; } } ``` -**JDBC Reference**: `DatabricksDriverFeatureFlagsContextFactory.java:27` maintains per-compute (host) feature flag contexts with reference counting. +#### Usage Example + +```csharp +// In DatabricksConnection constructor/initialization +// This runs for EACH new connection, using LATEST cached feature flags + +// Step 1: Get or create feature flag context +// - If context exists: returns existing context with latest cached flags +// - If new: creates context, does initial blocking fetch, starts background refresh +var featureFlagCache = FeatureFlagCache.GetInstance(); +var featureFlagContext = featureFlagCache.GetOrCreateContext(_host, CreateFeatureFlagHttpClient()); + +// Step 2: Merge feature flags into properties using LATEST cached values +// Each new connection gets a fresh merge with current flag values +Properties = MergePropertiesWithFeatureFlags( + userProperties, + featureFlagContext.GetAllFlags()); // Returns current cached flags + +// Step 3: Existing property parsing works unchanged! +// Feature flags are now transparently available as properties +bool IsTelemetryEnabled() +{ + // This works whether the value came from: + // - User property (highest priority) + // - Feature flag (merged in) + // - Or falls back to driver default + if (Properties.TryGetValue("telemetry.enabled", out var value)) + { + return bool.TryParse(value, out var result) && result; + } + return true; // Driver default +} + +// Same pattern for all other properties - no changes needed! +if (Properties.TryGetValue(DatabricksParameters.CloudFetchEnabled, out var cfValue)) +{ + // Value could be from user OR from feature flag - transparent! 
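+    // Sketch (parsing assumed, not prescribed here): the merged value is consumed
+    // the same way regardless of whether it came from the user or a feature flag:
+    // bool cloudFetchEnabled = bool.TryParse(cfValue, out var parsed) && parsed;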
+} +``` + +**Key Benefits**: +- Existing code like `Properties.TryGetValue()` continues to work unchanged +- Each new connection uses the **latest** cached feature flag values +- Feature flag integration is a one-time merge at connection initialization +- Properties are stable for the lifetime of the connection (no mid-connection changes) + +**JDBC Reference**: `DatabricksDriverFeatureFlagsContextFactory.java:27` maintains per-compute (host) feature flag contexts with reference counting. `DatabricksDriverFeatureFlagsContext.java` implements the caching, refresh scheduling, and API calls. --- @@ -257,7 +607,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Singleton factory that manages one telemetry client per host. @@ -319,7 +669,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Wraps telemetry exporter with circuit breaker pattern. @@ -372,7 +722,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Custom ActivityListener that aggregates metrics from Activity events @@ -460,7 +810,7 @@ private ActivityListener CreateListener() #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Aggregates metrics from activities by statement_id and includes session_id. @@ -549,38 +899,58 @@ flowchart TD **Purpose**: Export aggregated metrics to Databricks telemetry service. -**Location**: `Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.DatabricksTelemetryExporter` +**Location**: `AdbcDrivers.Databricks.Telemetry.DatabricksTelemetryExporter` + +**Status**: Implemented (WI-3.4) #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { public interface ITelemetryExporter { /// - /// Export metrics to Databricks service. Never throws. + /// Export telemetry frontend logs to the backend service. + /// Never throws exceptions (all swallowed and logged at TRACE level). /// Task ExportAsync( - IReadOnlyList metrics, + IReadOnlyList logs, CancellationToken ct = default); } internal sealed class DatabricksTelemetryExporter : ITelemetryExporter { + // Authenticated telemetry endpoint + internal const string AuthenticatedEndpoint = "/telemetry-ext"; + + // Unauthenticated telemetry endpoint + internal const string UnauthenticatedEndpoint = "/telemetry-unauth"; + public DatabricksTelemetryExporter( HttpClient httpClient, - DatabricksConnection connection, + string host, + bool isAuthenticated, TelemetryConfiguration config); public Task ExportAsync( - IReadOnlyList metrics, + IReadOnlyList logs, CancellationToken ct = default); + + // Creates TelemetryRequest wrapper with uploadTime and protoLogs + internal TelemetryRequest CreateTelemetryRequest(IReadOnlyList logs); } } ``` -**Same implementation as original design**: Circuit breaker, retry logic, endpoints. 
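+
+The request body assembled by `CreateTelemetryRequest` pairs an `uploadTime` (Unix milliseconds) with `protoLogs`, an array of JSON-serialized `TelemetryFrontendLog` entries. A minimal sketch, with the `TelemetryRequest` property names assumed rather than taken from the implementation:
+
+```csharp
+// Sketch only - TelemetryRequest/TelemetryFrontendLog shapes are assumed.
+internal TelemetryRequest CreateTelemetryRequest(IReadOnlyList<TelemetryFrontendLog> logs)
+{
+    return new TelemetryRequest
+    {
+        UploadTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), // Unix ms upload time
+        ProtoLogs = logs.Select(log => JsonSerializer.Serialize(log)).ToList(),
+    };
+}
+```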
+**Implementation Details**: +- Creates `TelemetryRequest` with `uploadTime` (Unix ms) and `protoLogs` (JSON-serialized `TelemetryFrontendLog` array) +- Uses `/telemetry-ext` for authenticated requests +- Uses `/telemetry-unauth` for unauthenticated requests +- Implements retry logic for transient failures (configurable via `MaxRetries` and `RetryDelayMs`) +- Uses `ExceptionClassifier` to identify terminal vs retryable errors +- Never throws exceptions (all caught and logged at TRACE level) +- Cancellation is propagated (not swallowed) --- @@ -609,7 +979,7 @@ Telemetry/ **File**: `TagDefinitions/TelemetryTag.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Defines export scope for telemetry tags. @@ -648,7 +1018,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/ConnectionOpenEvent.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Tag definitions for Connection.Open events. @@ -726,7 +1096,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/StatementExecutionEvent.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Tag definitions for Statement execution events. @@ -812,7 +1182,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/TelemetryTagRegistry.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Central registry for all telemetry tags and events. @@ -1069,9 +1439,15 @@ public sealed class TelemetryConfiguration public int CircuitBreakerThreshold { get; set; } = 5; public TimeSpan CircuitBreakerTimeout { get; set; } = TimeSpan.FromMinutes(1); - // Feature flag + // Feature flag name to check in the cached flags public const string FeatureFlagName = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"; + + // Feature flag endpoint (relative to host) + // {0} = driver version without OSS suffix + // NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side + public const string FeatureFlagEndpointFormat = + "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; } ``` @@ -1801,13 +2177,25 @@ The Activity-based design was selected because it: ## 12. 
Implementation Checklist ### Phase 1: Feature Flag Cache & Per-Host Management -- [ ] Create `FeatureFlagCache` singleton with per-host contexts +- [ ] Create `FeatureFlagCache` singleton with per-host contexts (in `Apache.Arrow.Adbc.Drivers.Databricks` namespace, not Telemetry) +- [ ] Make cache generic - return all flags, not just telemetry-specific ones - [ ] Implement `FeatureFlagContext` with reference counting -- [ ] Add cache expiration logic (15 minute TTL) -- [ ] Implement `FetchFeatureFlagAsync` to call feature endpoint +- [ ] Implement `GetFlagValue(string)` and `GetAllFlags()` methods for generic flag access +- [ ] Implement API call to `/api/2.0/connector-service/feature-flags/OSS_JDBC/{version}` (use JDBC endpoint initially) +- [ ] Parse `FeatureFlagsResponse` with `flags` array and `ttl_seconds` +- [ ] Implement initial blocking fetch on context creation +- [ ] Implement background refresh scheduler using server-provided `ttl_seconds` +- [ ] Add `Shutdown()` method to stop scheduler and cleanup +- [ ] Implement configuration priority: user properties > feature flags > driver defaults - [ ] Create `TelemetryClientManager` singleton - [ ] Implement `TelemetryClientHolder` with reference counting - [ ] Add unit tests for cache behavior and reference counting +- [ ] Add unit tests for background refresh scheduling +- [ ] Add unit tests for configuration priority order + +**Code Guidelines**: +- Avoid `#if` preprocessor directives - write code compatible with all target .NET versions (netstandard2.0, net472, net8.0) +- Use polyfills or runtime checks instead of compile-time conditionals where needed ### Phase 2: Circuit Breaker - [ ] Create `CircuitBreaker` class with state machine @@ -1837,7 +2225,7 @@ The Activity-based design was selected because it: ### Phase 5: Core Implementation - [ ] Create `DatabricksActivityListener` class - [ ] Create `MetricsAggregator` class (with exception buffering) -- [ ] Create `DatabricksTelemetryExporter` class +- [x] Create `DatabricksTelemetryExporter` class (WI-3.4) - [ ] Add necessary tags to existing activities (using defined constants) - [ ] Update connection to use per-host management @@ -1896,6 +2284,19 @@ This ensures compatibility with OTEL ecosystem. - Listener is optional (only activated when telemetry enabled) - Activity overhead already exists +### 13.4 Feature Flag Endpoint Migration + +**Question**: When should we migrate from `OSS_JDBC` to `OSS_ADBC` endpoint? + +**Current State**: The ADBC driver currently uses the JDBC feature flag endpoint (`/api/2.0/connector-service/feature-flags/OSS_JDBC/{version}`) because the ADBC endpoint is not yet configured on the server side. + +**Migration Plan**: +1. Server team configures the `OSS_ADBC` endpoint with appropriate feature flags +2. Update `TelemetryConfiguration.FeatureFlagEndpointFormat` to use `OSS_ADBC` +3. Coordinate with server team on feature flag name (`enableTelemetryForAdbc`) + +**Tracking**: Create a follow-up ticket to track this migration once server-side support is ready. + --- ## 14. References @@ -1920,8 +2321,12 @@ This ensures compatibility with OTEL ecosystem. 
- `CircuitBreakerTelemetryPushClient.java:15`: Circuit breaker wrapper - `CircuitBreakerManager.java:25`: Per-host circuit breaker management - `TelemetryPushClient.java:86-94`: Exception re-throwing for circuit breaker -- `TelemetryHelper.java:60-71`: Feature flag checking -- `DatabricksDriverFeatureFlagsContextFactory.java:27`: Per-host feature flag cache +- `TelemetryHelper.java:45-46,77`: Feature flag name and checking +- `DatabricksDriverFeatureFlagsContextFactory.java`: Per-host feature flag cache with reference counting +- `DatabricksDriverFeatureFlagsContext.java:30-33`: Feature flag API endpoint format +- `DatabricksDriverFeatureFlagsContext.java:48-58`: Initial fetch and background refresh scheduling +- `DatabricksDriverFeatureFlagsContext.java:89-140`: HTTP call and response parsing +- `FeatureFlagsResponse.java`: Response model with `flags` array and `ttl_seconds` --- diff --git a/csharp/doc/telemetry-sprint-plan.md b/csharp/doc/telemetry-sprint-plan.md index 02905a1c..19169659 100644 --- a/csharp/doc/telemetry-sprint-plan.md +++ b/csharp/doc/telemetry-sprint-plan.md @@ -436,6 +436,8 @@ Implement the core telemetry infrastructure including feature flag management, p #### WI-5.3: MetricsAggregator **Description**: Aggregates Activity data by statement_id, handles exception buffering. +**Status**: ✅ **COMPLETED** + **Location**: `csharp/src/Telemetry/MetricsAggregator.cs` **Input**: @@ -443,7 +445,7 @@ Implement the core telemetry infrastructure including feature flag management, p - ITelemetryExporter for flushing **Output**: -- Aggregated TelemetryMetric per statement +- Aggregated TelemetryEvent per statement - Batched flush on threshold or interval **Test Expectations**: @@ -452,13 +454,35 @@ Implement the core telemetry infrastructure including feature flag management, p |-----------|-----------|-------|-----------------| | Unit | `MetricsAggregator_ProcessActivity_ConnectionOpen_EmitsImmediately` | Connection.Open activity | Metric queued for export | | Unit | `MetricsAggregator_ProcessActivity_Statement_AggregatesByStatementId` | Multiple activities with same statement_id | Single aggregated metric | -| Unit | `MetricsAggregator_CompleteStatement_EmitsAggregatedMetric` | Call CompleteStatement() | Queues aggregated metric | -| Unit | `MetricsAggregator_FlushAsync_BatchSizeReached_ExportsMetrics` | 100 metrics (batch size) | Calls exporter | -| Unit | `MetricsAggregator_FlushAsync_TimeInterval_ExportsMetrics` | Wait 5 seconds | Calls exporter | +| Unit | `MetricsAggregator_CompleteStatement_EmitsAggregatedEvent` | Call CompleteStatement() | Queues aggregated metric | +| Unit | `MetricsAggregator_FlushAsync_BatchSizeReached_ExportsEvents` | Batch size reached | Calls exporter | +| Unit | `MetricsAggregator_FlushAsync_TimeInterval_ExportsEvents` | Wait for interval | Calls exporter | | Unit | `MetricsAggregator_RecordException_Terminal_FlushesImmediately` | Terminal exception | Immediately exports error metric | | Unit | `MetricsAggregator_RecordException_Retryable_BuffersUntilComplete` | Retryable exception | Buffers, exports on CompleteStatement | | Unit | `MetricsAggregator_ProcessActivity_ExceptionSwallowed_NoThrow` | Activity processing throws | No exception propagated | | Unit | `MetricsAggregator_ProcessActivity_FiltersTags_UsingRegistry` | Activity with sensitive tags | Only safe tags in metric | +| Unit | `MetricsAggregator_WrapInFrontendLog_CreatesValidStructure` | TelemetryEvent | Valid TelemetryFrontendLog structure | + +**Implementation Notes**: +- Uses 
`ConcurrentDictionary` for thread-safe aggregation by statement_id +- Connection events emit immediately without aggregation +- Statement events are aggregated until `CompleteStatement()` is called +- Terminal exceptions (via `ExceptionClassifier`) are queued immediately +- Retryable exceptions are buffered and only emitted when `CompleteStatement(failed: true)` is called +- Uses `TelemetryTagRegistry.ShouldExportToDatabricks()` for tag filtering +- Creates `TelemetryFrontendLog` wrapper with workspace_id, client context, and timestamp +- All exceptions swallowed and logged at TRACE level using `Debug.WriteLine()` +- Timer-based periodic flush using `System.Threading.Timer` +- Comprehensive test coverage with 29 unit tests in `MetricsAggregatorTests.cs` +- Test file location: `csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs` + +**Key Design Decisions**: +1. **ConcurrentDictionary for aggregation**: Thread-safe statement aggregation without explicit locking +2. **Nested StatementTelemetryContext**: Holds aggregated metrics and buffered exceptions per statement +3. **Immediate connection events**: Connection open events don't require aggregation and are emitted immediately +4. **Exception buffering**: Retryable exceptions are buffered per statement and only emitted on failed completion +5. **Timer-based flush**: Uses `System.Threading.Timer` for periodic flush based on `FlushIntervalMs` +6. **Graceful disposal**: `Dispose()` stops timer and performs final flush --- @@ -492,6 +516,8 @@ Implement the core telemetry infrastructure including feature flag management, p #### WI-5.5: TelemetryClient **Description**: Main telemetry client that coordinates listener, aggregator, and exporter. +**Status**: ✅ **COMPLETED** + **Location**: `csharp/src/Telemetry/TelemetryClient.cs` **Input**: @@ -511,6 +537,23 @@ Implement the core telemetry infrastructure including feature flag management, p | Unit | `TelemetryClient_CloseAsync_FlushesAndCancels` | N/A | Pending metrics flushed, background task cancelled | | Unit | `TelemetryClient_CloseAsync_ExceptionSwallowed` | Flush throws | No exception propagated | +**Implementation Notes**: +- Implements `ITelemetryClient` interface with `Host`, `ExportAsync`, and `CloseAsync` members +- Constructor creates the full telemetry pipeline: DatabricksTelemetryExporter → CircuitBreakerTelemetryExporter → MetricsAggregator → DatabricksActivityListener +- Starts a background flush task that periodically flushes metrics based on `FlushIntervalMs` configuration +- `CloseAsync` implements graceful shutdown: cancels background task, stops listener (which flushes pending metrics), waits for background task with 5s timeout, disposes resources +- All operations in `CloseAsync` wrapped in try-catch to swallow exceptions per telemetry requirement +- Updated `TelemetryClientManager.CreateClient()` to use `TelemetryClient` instead of `TelemetryClientAdapter` +- Comprehensive test coverage with 21 unit tests covering constructor, export, close, background flush, and thread safety +- Test file location: `csharp/test/Unit/Telemetry/TelemetryClientTests.cs` + +**Key Design Decisions**: +1. **Background flush task**: Uses `Task.Run` with internal loop and `Task.Delay` for periodic flushing +2. **Graceful shutdown**: CloseAsync uses 5-second timeout waiting for background task to prevent hanging +3. **Cross-framework compatibility**: Uses conditional compilation (`#if NET6_0_OR_GREATER`) for `Task.WaitAsync` vs `Task.WhenAny` fallback +4. 
**Exception swallowing**: Every operation in CloseAsync wrapped in try-catch per design requirement +5. **Idempotent close**: Uses lock and boolean flag to ensure CloseAsync can be called multiple times safely + --- ### Phase 6: Integration diff --git a/csharp/src/AdbcDrivers.Databricks.csproj b/csharp/src/AdbcDrivers.Databricks.csproj index 2ca93e77..0a54336a 100644 --- a/csharp/src/AdbcDrivers.Databricks.csproj +++ b/csharp/src/AdbcDrivers.Databricks.csproj @@ -7,6 +7,7 @@ + diff --git a/csharp/src/Auth/AuthHelper.cs b/csharp/src/Auth/AuthHelper.cs new file mode 100644 index 00000000..0cfb01e8 --- /dev/null +++ b/csharp/src/Auth/AuthHelper.cs @@ -0,0 +1,128 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; + +namespace AdbcDrivers.Databricks.Auth +{ + /// + /// Helper methods for authentication operations. + /// Provides shared functionality used by HttpHandlerFactory and FeatureFlagCache. + /// + internal static class AuthHelper + { + /// + /// Gets the access token from connection properties. + /// Tries access_token first, then falls back to token. + /// + /// Connection properties. + /// The token, or null if not found. + public static string? GetTokenFromProperties(IReadOnlyDictionary properties) + { + if (properties.TryGetValue(SparkParameters.AccessToken, out string? accessToken) && !string.IsNullOrEmpty(accessToken)) + { + return accessToken; + } + + if (properties.TryGetValue(SparkParameters.Token, out string? token) && !string.IsNullOrEmpty(token)) + { + return token; + } + + return null; + } + + /// + /// Gets the access token based on authentication configuration. + /// Supports token-based (PAT) and OAuth M2M (client_credentials) authentication. + /// + /// The Databricks host. + /// Connection properties containing auth configuration. + /// HTTP client timeout for OAuth token requests. + /// The access token, or null if no valid authentication is configured. + public static string? GetAccessToken(string host, IReadOnlyDictionary properties, TimeSpan timeout) + { + // Check if OAuth authentication is configured + bool useOAuth = properties.TryGetValue(SparkParameters.AuthType, out string? authType) && + SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && + authTypeValue == SparkAuthType.OAuth; + + if (useOAuth) + { + // Determine grant type (defaults to AccessToken if not specified) + properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? 
grantTypeStr); + DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + + if (grantType == DatabricksOAuthGrantType.ClientCredentials) + { + // OAuth M2M authentication + return GetOAuthClientCredentialsToken(host, properties, timeout); + } + else if (grantType == DatabricksOAuthGrantType.AccessToken) + { + // OAuth with access_token grant type + return GetTokenFromProperties(properties); + } + } + + // Non-OAuth authentication: use static Bearer token if provided + return GetTokenFromProperties(properties); + } + + /// + /// Gets an OAuth access token using client credentials (M2M) flow. + /// + /// The Databricks host. + /// Connection properties containing OAuth credentials. + /// HTTP client timeout. + /// The access token, or null if credentials are missing or token acquisition fails. + public static string? GetOAuthClientCredentialsToken(string host, IReadOnlyDictionary properties, TimeSpan timeout) + { + properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); + properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); + properties.TryGetValue(DatabricksParameters.OAuthScope, out string? scope); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + try + { + // Create a separate HttpClient for OAuth token acquisition with TLS and proxy settings + using var oauthHttpClient = Http.HttpClientFactory.CreateOAuthHttpClient(properties, timeout); + + using var tokenProvider = new OAuthClientCredentialsProvider( + oauthHttpClient, + clientId, + clientSecret, + host, + scope: scope ?? "sql", + timeoutMinutes: 1); + + // Get access token synchronously (blocking call) + return tokenProvider.GetAccessTokenAsync().GetAwaiter().GetResult(); + } + catch + { + // Auth failures should be handled by caller + return null; + } + } + } +} diff --git a/csharp/src/DatabricksConnection.cs b/csharp/src/DatabricksConnection.cs index 6306363f..1cfa5187 100644 --- a/csharp/src/DatabricksConnection.cs +++ b/csharp/src/DatabricksConnection.cs @@ -126,12 +126,13 @@ internal DatabricksConnection( IReadOnlyDictionary properties, Microsoft.IO.RecyclableMemoryStreamManager? memoryStreamManager, System.Buffers.ArrayPool? lz4BufferPool) - : base(MergeWithDefaultEnvironmentConfig(properties)) + : base(FeatureFlagCache.GetInstance().MergePropertiesWithFeatureFlags(MergeWithDefaultEnvironmentConfig(properties), s_assemblyVersion)) { // Use provided manager (from Database) or create new instance (for direct construction) RecyclableMemoryStreamManager = memoryStreamManager ?? new Microsoft.IO.RecyclableMemoryStreamManager(); // Use provided pool (from Database) or create new instance (for direct construction) Lz4BufferPool = lz4BufferPool ?? 
System.Buffers.ArrayPool.Create(maxArrayLength: 4 * 1024 * 1024, maxArraysPerBucket: 10); + ValidateProperties(); } @@ -526,7 +527,7 @@ internal override IArrowArrayStream NewReader(T statement, Schema schema, IRe isLz4Compressed = metadataResp.Lz4Compressed; } - HttpClient httpClient = new HttpClient(HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions, _proxyConfigurator)); + HttpClient httpClient = HttpClientFactory.CreateCloudFetchHttpClient(Properties); return new DatabricksCompositeReader(databricksStatement, schema, response, isLz4Compressed, httpClient); } diff --git a/csharp/src/DatabricksParameters.cs b/csharp/src/DatabricksParameters.cs index 77236931..3bfd4ba1 100644 --- a/csharp/src/DatabricksParameters.cs +++ b/csharp/src/DatabricksParameters.cs @@ -357,6 +357,12 @@ public class DatabricksParameters : SparkParameters /// public const string EnableSessionManagement = "adbc.databricks.rest.enable_session_management"; + /// + /// Timeout in seconds for feature flag API requests. + /// Default value is 10 seconds if not specified. + /// + public const string FeatureFlagTimeoutSeconds = "adbc.databricks.feature_flag_timeout_seconds"; + } /// diff --git a/csharp/src/FeatureFlagCache.cs b/csharp/src/FeatureFlagCache.cs new file mode 100644 index 00000000..8a0598e6 --- /dev/null +++ b/csharp/src/FeatureFlagCache.cs @@ -0,0 +1,492 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Microsoft.Extensions.Caching.Memory; + +namespace AdbcDrivers.Databricks +{ + /// + /// Singleton that manages feature flag cache per host using IMemoryCache. + /// Prevents rate limiting by caching feature flag responses with TTL-based expiration. + /// + /// + /// + /// This class implements a per-host caching pattern: + /// - Feature flags are cached by host to prevent rate limiting + /// - TTL-based expiration using server's ttl_seconds + /// - Background refresh task runs based on TTL within each FeatureFlagContext + /// - Automatic cleanup via IMemoryCache eviction + /// - Thread-safe using IMemoryCache and SemaphoreSlim for async locking + /// + /// + /// JDBC Reference: DatabricksDriverFeatureFlagsContextFactory.java + /// + /// + internal sealed class FeatureFlagCache : IDisposable + { + private static readonly FeatureFlagCache s_instance = new FeatureFlagCache(); + + /// + /// Activity source for feature flag tracing. + /// + private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.FeatureFlagCache"); + + private readonly IMemoryCache _cache; + private readonly SemaphoreSlim _createLock = new SemaphoreSlim(1, 1); + private bool _disposed; + + /// + /// Gets the singleton instance of the FeatureFlagCache. 
+ /// + public static FeatureFlagCache GetInstance() => s_instance; + + /// + /// Creates a new FeatureFlagCache with default MemoryCache. + /// + internal FeatureFlagCache() : this(new MemoryCache(new MemoryCacheOptions())) + { + } + + /// + /// Creates a new FeatureFlagCache with custom IMemoryCache (for testing). + /// + /// The memory cache to use. + internal FeatureFlagCache(IMemoryCache cache) + { + _cache = cache ?? throw new ArgumentNullException(nameof(cache)); + } + + /// + /// Gets or creates a feature flag context for the host asynchronously. + /// Waits for initial fetch to complete before returning. + /// + /// The host (Databricks workspace URL) to get or create a context for. + /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + /// The driver version for the API endpoint. + /// Optional custom endpoint format. If null, uses the default endpoint. + /// Cancellation token. + /// The feature flag context for the host. + /// Thrown when host is null or whitespace. + /// Thrown when httpClient is null. + public async Task GetOrCreateContextAsync( + string host, + HttpClient httpClient, + string driverVersion, + string? endpointFormat = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + var cacheKey = GetCacheKey(host); + + // Try to get existing context (fast path) + if (_cache.TryGetValue(cacheKey, out FeatureFlagContext? context) && context != null) + { + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.cache_hit", + tags: new ActivityTagsCollection { { "host", host } })); + + return context; + } + + // Cache miss - create new context with async lock to prevent duplicate creation + await _createLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + // Double-check after acquiring lock + if (_cache.TryGetValue(cacheKey, out context) && context != null) + { + return context; + } + + // Create context asynchronously - this waits for initial fetch to complete + context = await FeatureFlagContext.CreateAsync( + host, + httpClient, + driverVersion, + endpointFormat, + cancellationToken).ConfigureAwait(false); + + // Set cache options with TTL from context + var cacheOptions = new MemoryCacheEntryOptions() + .SetAbsoluteExpiration(context.Ttl) + .RegisterPostEvictionCallback(OnCacheEntryEvicted); + + _cache.Set(cacheKey, context, cacheOptions); + + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.context_created", + tags: new ActivityTagsCollection + { + { "host", host }, + { "ttl_seconds", context.Ttl.TotalSeconds } + })); + + return context; + } + finally + { + _createLock.Release(); + } + } + + /// + /// Synchronous wrapper for GetOrCreateContextAsync. + /// Used for backward compatibility with synchronous callers. + /// + public FeatureFlagContext GetOrCreateContext( + string host, + HttpClient httpClient, + string driverVersion, + string? endpointFormat = null) + { + return GetOrCreateContextAsync(host, httpClient, driverVersion, endpointFormat) + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + + /// + /// Callback invoked when a cache entry is evicted. + /// Disposes the context to clean up resources (stops background refresh task). 
+ /// + private static void OnCacheEntryEvicted(object key, object? value, EvictionReason reason, object? state) + { + if (value is FeatureFlagContext context) + { + context.Dispose(); + + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.context_evicted", + tags: new ActivityTagsCollection + { + { "host", context.Host }, + { "reason", reason.ToString() } + })); + } + } + + /// + /// Gets the cache key for a host. + /// + private static string GetCacheKey(string host) + { + return $"feature_flags:{host.ToLowerInvariant()}"; + } + + /// + /// Gets the number of hosts currently cached. + /// Note: IMemoryCache doesn't expose count directly, so this returns -1 if not available. + /// + internal int CachedHostCount + { + get + { + if (_cache is MemoryCache memoryCache) + { + return memoryCache.Count; + } + return -1; + } + } + + /// + /// Checks if a context exists for the specified host. + /// + /// The host to check. + /// True if a context exists, false otherwise. + internal bool HasContext(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _cache.TryGetValue(GetCacheKey(host), out _); + } + + /// + /// Gets the context for the specified host, if it exists. + /// Does not create a new context. + /// + /// The host to get the context for. + /// The context if found, null otherwise. + /// True if the context was found, false otherwise. + internal bool TryGetContext(string host, out FeatureFlagContext? context) + { + context = null; + + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _cache.TryGetValue(GetCacheKey(host), out context); + } + + /// + /// Removes a context from the cache. + /// + /// The host to remove. + internal void RemoveContext(string host) + { + if (!string.IsNullOrWhiteSpace(host)) + { + _cache.Remove(GetCacheKey(host)); + } + } + + /// + /// Clears all cached contexts. + /// This is primarily for testing purposes. + /// + internal void Clear() + { + if (_cache is MemoryCache memoryCache) + { + memoryCache.Compact(1.0); // Remove all entries + } + } + + /// + /// Disposes the cache and all cached contexts. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + _createLock.Dispose(); + + if (_cache is IDisposable disposableCache) + { + disposableCache.Dispose(); + } + } + + /// + /// Merges feature flags from server into properties asynchronously. + /// Feature flags (remote properties) have lower priority than user-specified properties (local properties). + /// Priority: Local Properties > Remote Properties (Feature Flags) > Driver Defaults + /// + /// Local properties from user configuration and environment. + /// The driver version for the API endpoint. + /// Cancellation token. + /// Properties with remote feature flags merged in (local properties take precedence). 
+ public async Task> MergePropertiesWithFeatureFlagsAsync( + IReadOnlyDictionary localProperties, + string assemblyVersion, + CancellationToken cancellationToken = default) + { + using var activity = s_activitySource.StartActivity("MergePropertiesWithFeatureFlags"); + + try + { + // Extract host from local properties + var host = TryGetHost(localProperties); + if (string.IsNullOrEmpty(host)) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_host" } })); + return localProperties; + } + + activity?.SetTag("feature_flags.host", host); + + // Create HttpClient for feature flag API + using var httpClient = CreateFeatureFlagHttpClient(host, assemblyVersion, localProperties); + + if (httpClient == null) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_auth_credentials" } })); + return localProperties; + } + + // Get or create feature flag context asynchronously (waits for initial fetch) + var context = await GetOrCreateContextAsync(host, httpClient, assemblyVersion, cancellationToken: cancellationToken).ConfigureAwait(false); + + // Get all flags from cache (remote properties) + var remoteProperties = context.GetAllFlags(); + + if (remoteProperties.Count == 0) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_flags_returned" } })); + return localProperties; + } + + activity?.SetTag("feature_flags.count", remoteProperties.Count); + activity?.AddEvent(new ActivityEvent("feature_flags.merging", + tags: new ActivityTagsCollection { { "flags_count", remoteProperties.Count } })); + + // Merge: remote properties (feature flags) as base, local properties override + // This ensures local properties take precedence over remote flags + return MergeProperties(remoteProperties, localProperties); + } + catch (Exception ex) + { + // Feature flag failures should never break the connection + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + activity?.AddEvent(new ActivityEvent("feature_flags.error", + tags: new ActivityTagsCollection + { + { "error.type", ex.GetType().Name }, + { "error.message", ex.Message } + })); + return localProperties; + } + } + + /// + /// Synchronous wrapper for MergePropertiesWithFeatureFlagsAsync. + /// Used for backward compatibility with synchronous callers. + /// + public IReadOnlyDictionary MergePropertiesWithFeatureFlags( + IReadOnlyDictionary localProperties, + string assemblyVersion) + { + return MergePropertiesWithFeatureFlagsAsync(localProperties, assemblyVersion) + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + + /// + /// Tries to extract the host from properties without throwing. + /// Handles cases where user puts protocol in host (e.g., "https://myhost.databricks.com"). + /// + /// Connection properties. + /// The host (without protocol), or null if not found. + internal static string? TryGetHost(IReadOnlyDictionary properties) + { + if (properties.TryGetValue(SparkParameters.HostName, out string? host) && !string.IsNullOrEmpty(host)) + { + // Handle case where user puts protocol in host + return StripProtocol(host); + } + + if (properties.TryGetValue(Apache.Arrow.Adbc.AdbcOptions.Uri, out string? uri) && !string.IsNullOrEmpty(uri)) + { + if (Uri.TryCreate(uri, UriKind.Absolute, out Uri? parsedUri)) + { + return parsedUri.Host; + } + } + + return null; + } + + /// + /// Strips protocol prefix from a host string if present. 
+ /// + /// The host string that may contain a protocol. + /// The host without protocol prefix. + private static string StripProtocol(string host) + { + // Try to parse as URI first to handle full URLs + if (Uri.TryCreate(host, UriKind.Absolute, out Uri? parsedUri) && + (parsedUri.Scheme == Uri.UriSchemeHttp || parsedUri.Scheme == Uri.UriSchemeHttps)) + { + return parsedUri.Host; + } + + // Fallback: strip common protocol prefixes manually + if (host.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + return host.Substring(8); + } + if (host.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + return host.Substring(7); + } + + return host; + } + + /// + /// Creates an HttpClient configured for the feature flag API. + /// Supports all authentication methods including: + /// - PAT (Personal Access Token) + /// - OAuth M2M (client_credentials) + /// - OAuth U2M (access_token) + /// - Workload Identity Federation (via MandatoryTokenExchangeDelegatingHandler) + /// Respects proxy settings and TLS options from connection properties. + /// + /// The Databricks host (without protocol). + /// The driver version for the User-Agent. + /// Connection properties for proxy, TLS, and auth configuration. + /// Configured HttpClient, or null if no valid authentication is available. + private static HttpClient? CreateFeatureFlagHttpClient( + string host, + string assemblyVersion, + IReadOnlyDictionary properties) + { + // Use centralized factory to create HttpClient with full auth handler chain + // This properly supports Workload Identity Federation via MandatoryTokenExchangeDelegatingHandler + return Http.HttpClientFactory.CreateFeatureFlagHttpClient(properties, host, assemblyVersion); + } + + /// + /// Merges two property dictionaries. Additional properties override base properties. + /// + /// Base properties (lower priority). + /// Additional properties (higher priority). + /// Merged properties. + private static IReadOnlyDictionary MergeProperties( + IReadOnlyDictionary baseProperties, + IReadOnlyDictionary additionalProperties) + { + var merged = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Add base properties first + foreach (var kvp in baseProperties) + { + merged[kvp.Key] = kvp.Value; + } + + // Additional properties override base properties + foreach (var kvp in additionalProperties) + { + merged[kvp.Key] = kvp.Value; + } + + return merged; + } + } +} diff --git a/csharp/src/FeatureFlagContext.cs b/csharp/src/FeatureFlagContext.cs new file mode 100644 index 00000000..1cf44f7b --- /dev/null +++ b/csharp/src/FeatureFlagContext.cs @@ -0,0 +1,370 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Tracing; + +namespace AdbcDrivers.Databricks +{ + /// + /// Holds feature flag state for a host. + /// Cached by FeatureFlagCache with TTL-based expiration. + /// + /// + /// Each host (Databricks workspace) has one FeatureFlagContext instance + /// that is shared across all connections to that host. The context: + /// - Caches all feature flags returned by the server + /// - Tracks TTL from server response for cache expiration + /// - Runs a background refresh task based on TTL + /// + /// Thread-safety is ensured using ConcurrentDictionary for flag storage. + /// + /// JDBC Reference: DatabricksDriverFeatureFlagsContext.java + /// + internal sealed class FeatureFlagContext : IDisposable + { + /// + /// Activity source for feature flag tracing. + /// + private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.FeatureFlags"); + + /// + /// Assembly version for the driver. + /// + private static readonly string s_assemblyVersion = ApacheUtility.GetAssemblyVersion(typeof(FeatureFlagContext)); + + /// + /// Default TTL (15 minutes) if server doesn't specify ttl_seconds. + /// + public static readonly TimeSpan DefaultTtl = TimeSpan.FromMinutes(15); + + /// + /// Default feature flag endpoint format. {0} = driver version. + /// NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side. + /// + internal const string DefaultFeatureFlagEndpointFormat = "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; + + private readonly string _host; + private readonly string _driverVersion; + private readonly string _endpointFormat; + private readonly HttpClient? _httpClient; + private readonly ConcurrentDictionary _flags; + private readonly CancellationTokenSource _refreshCts; + private readonly object _ttlLock = new object(); + + private Task? _refreshTask; + private TimeSpan _ttl; + private bool _disposed; + + /// + /// Gets the current TTL (from server ttl_seconds). + /// + public TimeSpan Ttl + { + get + { + lock (_ttlLock) + { + return _ttl; + } + } + internal set + { + lock (_ttlLock) + { + _ttl = value; + } + } + } + + /// + /// Gets the host this context is for. + /// + public string Host => _host; + + /// + /// Gets the current refresh interval (alias for Ttl). + /// + public TimeSpan RefreshInterval => Ttl; + + /// + /// Internal constructor - use CreateAsync factory method for production code. + /// Made internal to allow test code to create instances without HTTP calls. + /// + internal FeatureFlagContext(string host, HttpClient? httpClient, string driverVersion, string? endpointFormat) + { + _host = host; + _httpClient = httpClient; + _driverVersion = driverVersion ?? "1.0.0"; + _endpointFormat = endpointFormat ?? DefaultFeatureFlagEndpointFormat; + _flags = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _ttl = DefaultTtl; + _refreshCts = new CancellationTokenSource(); + } + + /// + /// Creates a new context with the given HTTP client. + /// Performs initial async fetch to populate cache, then starts background refresh task. + /// + /// The Databricks host. 
+ /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + /// The driver version for the API endpoint. + /// Optional custom endpoint format. If null, uses the default endpoint. + /// Cancellation token for the initial fetch. + /// A fully initialized FeatureFlagContext. + public static async Task CreateAsync( + string host, + HttpClient httpClient, + string driverVersion, + string? endpointFormat = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + var context = new FeatureFlagContext(host, httpClient, driverVersion, endpointFormat); + + // Initial async fetch - wait for it to complete + await context.FetchFeatureFlagsAsync("Initial", cancellationToken).ConfigureAwait(false); + + // Start background refresh task + context.StartBackgroundRefresh(); + + return context; + } + + /// + /// Gets a feature flag value by name. + /// Returns null if the flag is not found. + /// + /// The feature flag name. + /// The flag value, or null if not found. + public string? GetFlagValue(string flagName) + { + if (string.IsNullOrWhiteSpace(flagName)) + { + return null; + } + + return _flags.TryGetValue(flagName, out var value) ? value : null; + } + + /// + /// Gets all cached feature flags as a dictionary. + /// Can be used to merge with user properties. + /// + /// A read-only dictionary of all cached flags. + public IReadOnlyDictionary GetAllFlags() + { + // Return a snapshot to avoid concurrency issues + return new Dictionary(_flags, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Starts the background refresh task that periodically fetches flags based on TTL. + /// + private void StartBackgroundRefresh() + { + _refreshTask = Task.Run(async () => + { + while (!_refreshCts.Token.IsCancellationRequested) + { + try + { + // Wait for TTL duration before refreshing + await Task.Delay(Ttl, _refreshCts.Token).ConfigureAwait(false); + + if (!_refreshCts.Token.IsCancellationRequested) + { + await FetchFeatureFlagsAsync("Background", _refreshCts.Token).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + // Normal cancellation, exit the loop + break; + } + catch (Exception ex) + { + // Log error but continue the refresh loop + Activity.Current?.AddEvent("feature_flags.background_refresh.error", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + }, _refreshCts.Token); + } + + /// + /// Disposes the context and stops the background refresh task. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + // Cancel the background refresh task + _refreshCts.Cancel(); + + // Wait briefly for the task to complete (don't block indefinitely) + try + { + _refreshTask?.Wait(TimeSpan.FromSeconds(1)); + } + catch (AggregateException) + { + // Task was cancelled, ignore + } + + _refreshCts.Dispose(); + } + + /// + /// Fetches feature flags from the API endpoint asynchronously. + /// + /// Type of fetch for logging purposes (e.g., "Initial" or "Background"). + /// Cancellation token. 
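A minimal usage sketch for the public surface of `FeatureFlagContext` (`CreateAsync`, `GetFlagValue`, `GetAllFlags`), assuming `CreateAsync` returns `Task<FeatureFlagContext>`, that `GetAllFlags()` returns `IReadOnlyDictionary<string, string>`, and that `host`, `featureFlagHttpClient`, and `ct` are supplied by the caller:

```csharp
// Illustrative only. In the driver itself the per-host FeatureFlagCache owns
// the context's lifetime; this sketch just shows the call pattern.
FeatureFlagContext context = await FeatureFlagContext.CreateAsync(
    host,
    featureFlagHttpClient,          // e.g. from HttpClientFactory.CreateFeatureFlagHttpClient
    driverVersion: "1.0.0",
    cancellationToken: ct);

// Single flag lookup; null when the server did not return the flag.
string? telemetryEnabled = context.GetFlagValue(
    "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc");

// Snapshot of all cached flags, e.g. for merging under user-specified properties.
IReadOnlyDictionary<string, string> allFlags = context.GetAllFlags();
```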
+ private async Task FetchFeatureFlagsAsync(string fetchType, CancellationToken cancellationToken) + { + if (_httpClient == null) + { + return; + } + + using var activity = s_activitySource.StartActivity($"FetchFeatureFlags.{fetchType}"); + activity?.SetTag("feature_flags.host", _host); + activity?.SetTag("feature_flags.fetch_type", fetchType); + + try + { + var endpoint = string.Format(_endpointFormat, _driverVersion); + activity?.SetTag("feature_flags.endpoint", endpoint); + + var response = await _httpClient.GetAsync(endpoint, cancellationToken).ConfigureAwait(false); + + activity?.SetTag("feature_flags.response.status_code", (int)response.StatusCode); + + // Use the standard EnsureSuccessOrThrow extension method + response.EnsureSuccessOrThrow(); + + var content = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + ProcessResponse(content, activity); + + activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (OperationCanceledException) + { + // Propagate cancellation + throw; + } + catch (Exception ex) + { + // Swallow exceptions - telemetry should not break the connection + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + activity?.AddEvent("feature_flags.fetch.failed", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + + /// + /// Processes the JSON response and updates the cache. + /// + /// The JSON response content. + /// The current activity for tracing. + private void ProcessResponse(string content, Activity? activity) + { + try + { + var response = JsonSerializer.Deserialize(content); + + if (response?.Flags != null) + { + foreach (var flag in response.Flags) + { + if (!string.IsNullOrEmpty(flag.Name)) + { + _flags[flag.Name] = flag.Value ?? string.Empty; + } + } + + activity?.SetTag("feature_flags.count", response.Flags.Count); + activity?.AddEvent("feature_flags.updated", [ + new("flags_count", response.Flags.Count) + ]); + } + + // Update TTL if server provides a different value + if (response?.TtlSeconds != null && response.TtlSeconds > 0) + { + Ttl = TimeSpan.FromSeconds(response.TtlSeconds.Value); + activity?.SetTag("feature_flags.ttl_seconds", response.TtlSeconds.Value); + } + } + catch (JsonException ex) + { + activity?.AddEvent("feature_flags.parse.failed", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + + /// + /// Clears all cached flags. + /// This is primarily for testing purposes. + /// + internal void ClearFlags() + { + _flags.Clear(); + } + + /// + /// Sets a flag value directly. + /// This is primarily for testing purposes. + /// + internal void SetFlag(string name, string value) + { + _flags[name] = value; + } + } +} diff --git a/csharp/src/FeatureFlagsResponse.cs b/csharp/src/FeatureFlagsResponse.cs new file mode 100644 index 00000000..088fcf3d --- /dev/null +++ b/csharp/src/FeatureFlagsResponse.cs @@ -0,0 +1,59 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace AdbcDrivers.Databricks +{ + /// + /// Response model for the feature flags API. + /// Maps to the JSON response from /api/2.0/connector-service/feature-flags/{driver_type}/{version}. + /// + internal sealed class FeatureFlagsResponse + { + /// + /// Array of feature flag entries with name and value. + /// + [JsonPropertyName("flags")] + public List? Flags { get; set; } + + /// + /// Server-controlled refresh interval in seconds. + /// Default is 900 (15 minutes) if not provided. + /// + [JsonPropertyName("ttl_seconds")] + public int? TtlSeconds { get; set; } + } + + /// + /// Individual feature flag entry with name and value. + /// + internal sealed class FeatureFlagEntry + { + /// + /// The feature flag name (e.g., "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"). + /// + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + + /// + /// The feature flag value as a string (e.g., "true", "false", "10"). + /// + [JsonPropertyName("value")] + public string Value { get; set; } = string.Empty; + } +} diff --git a/csharp/src/Http/HttpClientFactory.cs b/csharp/src/Http/HttpClientFactory.cs new file mode 100644 index 00000000..b9282ae3 --- /dev/null +++ b/csharp/src/Http/HttpClientFactory.cs @@ -0,0 +1,140 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Hive2; + +namespace AdbcDrivers.Databricks.Http +{ + /// + /// Centralized factory for creating HttpClient instances with proper TLS and proxy configuration. + /// All HttpClient creation should go through this factory to ensure consistent configuration. + /// + internal static class HttpClientFactory + { + /// + /// Creates an HttpClientHandler with TLS and proxy settings from connection properties. + /// This is the base handler used by all HttpClient instances. + /// + /// Connection properties containing TLS and proxy configuration. + /// Configured HttpClientHandler. + public static HttpClientHandler CreateHandler(IReadOnlyDictionary properties) + { + var tlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(properties); + var proxyConfigurator = HiveServer2ProxyConfigurator.FromProperties(properties); + return HiveServer2TlsImpl.NewHttpClientHandler(tlsOptions, proxyConfigurator); + } + + /// + /// Creates a basic HttpClient with TLS and proxy settings. + /// Use this for simple HTTP operations that don't need auth handlers or retries. + /// + /// Connection properties containing TLS and proxy configuration. + /// Optional timeout. If not specified, uses default HttpClient timeout. + /// Configured HttpClient. + public static HttpClient CreateBasicHttpClient(IReadOnlyDictionary properties, TimeSpan? 
timeout = null) + { + var handler = CreateHandler(properties); + var httpClient = new HttpClient(handler); + + if (timeout.HasValue) + { + httpClient.Timeout = timeout.Value; + } + + return httpClient; + } + + /// + /// Creates an HttpClient for CloudFetch downloads. + /// Includes TLS and proxy settings but no auth headers (CloudFetch uses pre-signed URLs). + /// + /// Connection properties containing TLS and proxy configuration. + /// Configured HttpClient for CloudFetch. + public static HttpClient CreateCloudFetchHttpClient(IReadOnlyDictionary properties) + { + int timeoutMinutes = PropertyHelper.GetPositiveIntPropertyWithValidation( + properties, + DatabricksParameters.CloudFetchTimeoutMinutes, + DatabricksConstants.DefaultCloudFetchTimeoutMinutes); + + return CreateBasicHttpClient(properties, TimeSpan.FromMinutes(timeoutMinutes)); + } + + /// + /// Creates an HttpClient for feature flag API calls. + /// Includes TLS, proxy settings, and full authentication handler chain. + /// Supports PAT, OAuth M2M, OAuth U2M, and Workload Identity Federation. + /// + /// Connection properties. + /// The Databricks host (without protocol). + /// The driver version for the User-Agent. + /// Configured HttpClient for feature flags, or null if no valid authentication is available. + public static HttpClient? CreateFeatureFlagHttpClient( + IReadOnlyDictionary properties, + string host, + string assemblyVersion) + { + const int DefaultFeatureFlagTimeoutSeconds = 10; + + var timeoutSeconds = PropertyHelper.GetPositiveIntPropertyWithValidation( + properties, + DatabricksParameters.FeatureFlagTimeoutSeconds, + DefaultFeatureFlagTimeoutSeconds); + + // Create handler with full auth chain (including WIF support) + var handler = HttpHandlerFactory.CreateFeatureFlagHandler(properties, host, timeoutSeconds); + if (handler == null) + { + return null; + } + + var httpClient = new HttpClient(handler) + { + BaseAddress = new Uri($"https://{host}"), + Timeout = TimeSpan.FromSeconds(timeoutSeconds) + }; + + // Use same User-Agent format as other Databricks HTTP clients + string userAgent = $"DatabricksJDBCDriverOSS/{assemblyVersion} (ADBC)"; + string userAgentEntry = PropertyHelper.GetStringProperty(properties, "adbc.spark.user_agent_entry", string.Empty); + if (!string.IsNullOrWhiteSpace(userAgentEntry)) + { + userAgent = $"{userAgent} {userAgentEntry}"; + } + + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(userAgent); + + return httpClient; + } + + /// + /// Creates an HttpClient for OAuth token operations. + /// Includes TLS and proxy settings with configurable timeout. + /// + /// Connection properties. + /// Timeout for OAuth operations. + /// Configured HttpClient for OAuth. + public static HttpClient CreateOAuthHttpClient(IReadOnlyDictionary properties, TimeSpan timeout) + { + return CreateBasicHttpClient(properties, timeout); + } + } +} diff --git a/csharp/src/Http/HttpHandlerFactory.cs b/csharp/src/Http/HttpHandlerFactory.cs index 31778e5d..0f27363c 100644 --- a/csharp/src/Http/HttpHandlerFactory.cs +++ b/csharp/src/Http/HttpHandlerFactory.cs @@ -126,6 +126,151 @@ internal class HandlerResult public HttpClient? AuthHttpClient { get; set; } } + /// + /// Checks if OAuth authentication is configured in properties. + /// + private static bool IsOAuthEnabled(IReadOnlyDictionary properties) + { + return properties.TryGetValue(SparkParameters.AuthType, out string? 
authType) && + SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && + authTypeValue == SparkAuthType.OAuth; + } + + /// + /// Gets the OAuth grant type from properties. + /// + private static DatabricksOAuthGrantType GetOAuthGrantType(IReadOnlyDictionary properties) + { + properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); + DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + return grantType; + } + + /// + /// Creates an OAuthClientCredentialsProvider for M2M authentication. + /// Returns null if client ID or secret is missing. + /// + private static OAuthClientCredentialsProvider? CreateOAuthClientCredentialsProvider( + IReadOnlyDictionary properties, + HttpClient authHttpClient, + string host) + { + properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); + properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); + properties.TryGetValue(DatabricksParameters.OAuthScope, out string? scope); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + return new OAuthClientCredentialsProvider( + authHttpClient, + clientId, + clientSecret, + host, + scope: scope ?? "sql", + timeoutMinutes: 1); + } + + /// + /// Adds authentication handlers to the handler chain. + /// Returns the new handler chain with auth handlers added, or null if auth is required but not available. + /// + /// The current handler chain. + /// Connection properties. + /// The Databricks host. + /// HTTP client for auth operations (required for OAuth). + /// Identity federation client ID (optional). + /// Whether to enable JWT token refresh for access_token grant type. + /// Output: the auth HTTP client if created. + /// Handler with auth handlers added, or null if required auth is not available. + private static HttpMessageHandler? AddAuthHandlers( + HttpMessageHandler handler, + IReadOnlyDictionary properties, + string host, + HttpClient? authHttpClient, + string? identityFederationClientId, + bool enableTokenRefresh, + out HttpClient? authHttpClientOut) + { + authHttpClientOut = authHttpClient; + + if (IsOAuthEnabled(properties)) + { + if (authHttpClient == null) + { + return null; // OAuth requires auth HTTP client + } + + ITokenExchangeClient tokenExchangeClient = new TokenExchangeClient(authHttpClient, host); + + // Mandatory token exchange should be the inner handler so that it happens + // AFTER the OAuth handlers (e.g. after M2M sets the access token) + handler = new MandatoryTokenExchangeDelegatingHandler( + handler, + tokenExchangeClient, + identityFederationClientId); + + var grantType = GetOAuthGrantType(properties); + + if (grantType == DatabricksOAuthGrantType.ClientCredentials) + { + var tokenProvider = CreateOAuthClientCredentialsProvider(properties, authHttpClient, host); + if (tokenProvider == null) + { + return null; // Missing client credentials + } + handler = new OAuthDelegatingHandler(handler, tokenProvider); + } + else if (grantType == DatabricksOAuthGrantType.AccessToken) + { + string? accessToken = AuthHelper.GetTokenFromProperties(properties); + if (string.IsNullOrEmpty(accessToken)) + { + return null; // No access token + } + + if (enableTokenRefresh && + properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string? 
tokenRenewLimitStr) && + int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit) && + tokenRenewLimit > 0 && + JwtTokenDecoder.TryGetExpirationTime(accessToken, out DateTime expiryTime)) + { + handler = new TokenRefreshDelegatingHandler( + handler, + tokenExchangeClient, + accessToken, + expiryTime, + tokenRenewLimit); + } + else + { + handler = new StaticBearerTokenHandler(handler, accessToken); + } + } + else + { + return null; // Unknown grant type + } + } + else + { + // Non-OAuth authentication: use static Bearer token if provided + string? accessToken = AuthHelper.GetTokenFromProperties(properties); + if (!string.IsNullOrEmpty(accessToken)) + { + handler = new StaticBearerTokenHandler(handler, accessToken); + } + else + { + return null; // No auth available + } + } + + return handler; + } + /// /// Creates HTTP handlers with OAuth and other delegating handlers. /// @@ -175,115 +320,83 @@ public static HandlerResult CreateHandlers(HandlerConfig config) authHandler = new ThriftErrorMessageHandler(authHandler); } + // Create auth HTTP client for OAuth if needed HttpClient? authHttpClient = null; - - // Check if OAuth authentication is configured - bool useOAuth = config.Properties.TryGetValue(SparkParameters.AuthType, out string? authType) && - SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && - authTypeValue == SparkAuthType.OAuth; - - if (useOAuth) + if (IsOAuthEnabled(config.Properties)) { - // Create auth HTTP client for token operations authHttpClient = new HttpClient(authHandler) { Timeout = TimeSpan.FromMinutes(config.TimeoutMinutes) }; + } - ITokenExchangeClient tokenExchangeClient = new TokenExchangeClient(authHttpClient, config.Host); - - // Mandatory token exchange should be the inner handler so that it happens - // AFTER the OAuth handlers (e.g. after M2M sets the access token) - handler = new MandatoryTokenExchangeDelegatingHandler( - handler, - tokenExchangeClient, - config.IdentityFederationClientId); - - // Determine grant type (defaults to AccessToken if not specified) - config.Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); - DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + // Add auth handlers + var resultHandler = AddAuthHandlers( + handler, + config.Properties, + config.Host, + authHttpClient, + config.IdentityFederationClientId, + enableTokenRefresh: true, + out authHttpClient); - // Add OAuth client credentials handler if OAuth M2M authentication is being used - if (grantType == DatabricksOAuthGrantType.ClientCredentials) - { - config.Properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); - config.Properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); - config.Properties.TryGetValue(DatabricksParameters.OAuthScope, out string? scope); - - var tokenProvider = new OAuthClientCredentialsProvider( - authHttpClient, - clientId!, - clientSecret!, - config.Host, - scope: scope ?? "sql", - timeoutMinutes: 1 - ); + return new HandlerResult + { + Handler = resultHandler ?? 
handler, // Fall back to handler without auth if auth not configured + AuthHttpClient = authHttpClient + }; + } - handler = new OAuthDelegatingHandler(handler, tokenProvider); - } - // For access_token grant type, get the access token from properties - else if (grantType == DatabricksOAuthGrantType.AccessToken) - { - // Get the access token from properties - string accessToken = string.Empty; - if (config.Properties.TryGetValue(SparkParameters.AccessToken, out string? token)) - { - accessToken = token ?? string.Empty; - } - else if (config.Properties.TryGetValue(SparkParameters.Token, out string? fallbackToken)) - { - accessToken = fallbackToken ?? string.Empty; - } + /// + /// Creates an HTTP handler chain specifically for feature flag API calls. + /// This is a simplified version of CreateHandlers that includes auth but not + /// tracing, retry, or Thrift error handlers. + /// + /// Handler chain order (outermost to innermost): + /// 1. OAuth handlers (OAuthDelegatingHandler or StaticBearerTokenHandler) - token management + /// 2. MandatoryTokenExchangeDelegatingHandler (if OAuth) - workload identity federation + /// 3. Base HTTP handler - actual network communication + /// + /// This properly supports all authentication methods including: + /// - PAT (Personal Access Token) + /// - OAuth M2M (client_credentials) + /// - OAuth U2M (access_token) + /// - Workload Identity Federation (via MandatoryTokenExchangeDelegatingHandler) + /// + /// Connection properties containing configuration. + /// The Databricks host (without protocol). + /// HTTP client timeout in seconds. + /// Configured HttpMessageHandler, or null if no valid authentication is available. + public static HttpMessageHandler? CreateFeatureFlagHandler( + IReadOnlyDictionary properties, + string host, + int timeoutSeconds) + { + HttpMessageHandler baseHandler = HttpClientFactory.CreateHandler(properties); - if (!string.IsNullOrEmpty(accessToken)) - { - // Check if token renewal is configured and token is JWT - if (config.Properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string? tokenRenewLimitStr) && - int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit) && - tokenRenewLimit > 0 && - JwtTokenDecoder.TryGetExpirationTime(accessToken, out DateTime expiryTime)) - { - // Use TokenRefreshDelegatingHandler for JWT tokens with renewal configured - handler = new TokenRefreshDelegatingHandler( - handler, - tokenExchangeClient, - accessToken, - expiryTime, - tokenRenewLimit); - } - else - { - // Use StaticBearerTokenHandler for tokens without renewal - handler = new StaticBearerTokenHandler(handler, accessToken); - } - } - } - } - else + // Create auth HTTP client for OAuth if needed + HttpClient? authHttpClient = null; + if (IsOAuthEnabled(properties)) { - // Non-OAuth authentication: use static Bearer token if provided - // Try access_token first, then fall back to token - string accessToken = string.Empty; - if (config.Properties.TryGetValue(SparkParameters.AccessToken, out string? token)) - { - accessToken = token ?? string.Empty; - } - else if (config.Properties.TryGetValue(SparkParameters.Token, out string? fallbackToken)) - { - accessToken = fallbackToken ?? 
string.Empty; - } - - if (!string.IsNullOrEmpty(accessToken)) + HttpMessageHandler baseAuthHandler = HttpClientFactory.CreateHandler(properties); + authHttpClient = new HttpClient(baseAuthHandler) { - handler = new StaticBearerTokenHandler(handler, accessToken); - } + Timeout = TimeSpan.FromSeconds(timeoutSeconds) + }; } - return new HandlerResult - { - Handler = handler, - AuthHttpClient = authHttpClient - }; + // Get identity federation client ID + properties.TryGetValue(DatabricksParameters.IdentityFederationClientId, out string? identityFederationClientId); + + // Add auth handlers (no token refresh for feature flags) + return AddAuthHandlers( + baseHandler, + properties, + host, + authHttpClient, + identityFederationClientId, + enableTokenRefresh: false, + out _); } } } diff --git a/csharp/src/StatementExecution/StatementExecutionConnection.cs b/csharp/src/StatementExecution/StatementExecutionConnection.cs index d18a456b..bc2a5a53 100644 --- a/csharp/src/StatementExecution/StatementExecutionConnection.cs +++ b/csharp/src/StatementExecution/StatementExecutionConnection.cs @@ -221,11 +221,8 @@ private StatementExecutionConnection( // Create a separate HTTP client for CloudFetch downloads (without auth headers) // This is needed because CloudFetch uses pre-signed URLs from cloud storage (S3, Azure Blob, etc.) // and those services reject requests with multiple authentication methods - int timeoutMinutes = PropertyHelper.GetPositiveIntPropertyWithValidation(properties, DatabricksParameters.CloudFetchTimeoutMinutes, DatabricksConstants.DefaultCloudFetchTimeoutMinutes); - _cloudFetchHttpClient = new HttpClient() - { - Timeout = TimeSpan.FromMinutes(timeoutMinutes) - }; + // Note: We still need proxy and TLS configuration for corporate network access + _cloudFetchHttpClient = HttpClientFactory.CreateCloudFetchHttpClient(properties); // Create REST API client _client = new StatementExecutionClient(_httpClient, baseUrl); @@ -251,8 +248,8 @@ private HttpClient CreateHttpClient(IReadOnlyDictionary properti var config = new HttpHandlerFactory.HandlerConfig { - BaseHandler = new HttpClientHandler(), - BaseAuthHandler = new HttpClientHandler(), + BaseHandler = HttpClientFactory.CreateHandler(properties), + BaseAuthHandler = HttpClientFactory.CreateHandler(properties), Properties = properties, Host = GetHost(properties), ActivityTracer = this, diff --git a/csharp/src/Telemetry/CircuitBreakerManager.cs b/csharp/src/Telemetry/CircuitBreakerManager.cs new file mode 100644 index 00000000..22e05bc8 --- /dev/null +++ b/csharp/src/Telemetry/CircuitBreakerManager.cs @@ -0,0 +1,202 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Singleton that manages circuit breakers per host. 
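A hedged sketch of resolving and using the per-host breaker; it assumes the `CircuitBreaker` type defined elsewhere in this change exposes `ExecuteAsync(Func<Task>)` and throws `CircuitBreakerOpenException` while open, mirroring how `CircuitBreakerTelemetryExporter` uses it below:

```csharp
// Illustrative only; SendTelemetryBatchAsync and host are hypothetical.
var breaker = CircuitBreakerManager.GetInstance().GetCircuitBreaker(host);

try
{
    // Failures inside ExecuteAsync are counted toward opening the circuit.
    await breaker.ExecuteAsync(async () =>
    {
        await SendTelemetryBatchAsync().ConfigureAwait(false);
    });
}
catch (CircuitBreakerOpenException)
{
    // Circuit is open for this host: drop the work instead of hammering the endpoint.
}
```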
+ /// + /// + /// This class implements the per-host circuit breaker pattern from the JDBC driver: + /// - Each host gets its own circuit breaker for isolation + /// - One failing endpoint does not affect other endpoints + /// - Thread-safe using ConcurrentDictionary + /// + /// JDBC Reference: CircuitBreakerManager.java:25 + /// + internal sealed class CircuitBreakerManager + { + private static readonly CircuitBreakerManager s_instance = new CircuitBreakerManager(); + + private readonly ConcurrentDictionary _circuitBreakers; + private readonly CircuitBreakerConfig _defaultConfig; + + /// + /// Gets the singleton instance of the CircuitBreakerManager. + /// + public static CircuitBreakerManager GetInstance() => s_instance; + + /// + /// Creates a new CircuitBreakerManager with default configuration. + /// + internal CircuitBreakerManager() + : this(new CircuitBreakerConfig()) + { + } + + /// + /// Creates a new CircuitBreakerManager with the specified default configuration. + /// + /// The default configuration for new circuit breakers. + internal CircuitBreakerManager(CircuitBreakerConfig defaultConfig) + { + _circuitBreakers = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _defaultConfig = defaultConfig ?? throw new ArgumentNullException(nameof(defaultConfig)); + } + + /// + /// Gets or creates a circuit breaker for the specified host. + /// + /// The host (Databricks workspace URL) to get or create a circuit breaker for. + /// The circuit breaker for the host. + /// Thrown when host is null or whitespace. + /// + /// This method is thread-safe. If multiple threads call this method simultaneously + /// for the same host, they will all receive the same circuit breaker instance. + /// The circuit breaker is created lazily on first access. + /// + public CircuitBreaker GetCircuitBreaker(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + var circuitBreaker = _circuitBreakers.GetOrAdd(host, _ => + { + Debug.WriteLine($"[DEBUG] CircuitBreakerManager: Creating circuit breaker for host '{host}'"); + return new CircuitBreaker(_defaultConfig); + }); + + return circuitBreaker; + } + + /// + /// Gets or creates a circuit breaker for the specified host with custom configuration. + /// + /// The host (Databricks workspace URL) to get or create a circuit breaker for. + /// The configuration to use for this host's circuit breaker. + /// The circuit breaker for the host. + /// Thrown when host is null or whitespace. + /// Thrown when config is null. + /// + /// Note: If a circuit breaker already exists for the host, the existing instance + /// with its original configuration is returned. The provided config is only used + /// when creating a new circuit breaker. + /// + public CircuitBreaker GetCircuitBreaker(string host, CircuitBreakerConfig config) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + var circuitBreaker = _circuitBreakers.GetOrAdd(host, _ => + { + Debug.WriteLine($"[DEBUG] CircuitBreakerManager: Creating circuit breaker for host '{host}' with custom config"); + return new CircuitBreaker(config); + }); + + return circuitBreaker; + } + + /// + /// Gets the number of hosts with circuit breakers. 
+ /// + internal int CircuitBreakerCount => _circuitBreakers.Count; + + /// + /// Checks if a circuit breaker exists for the specified host. + /// + /// The host to check. + /// True if a circuit breaker exists, false otherwise. + internal bool HasCircuitBreaker(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _circuitBreakers.ContainsKey(host); + } + + /// + /// Tries to get an existing circuit breaker for the specified host. + /// Does not create a new circuit breaker if one doesn't exist. + /// + /// The host to get the circuit breaker for. + /// The circuit breaker if found, null otherwise. + /// True if the circuit breaker was found, false otherwise. + internal bool TryGetCircuitBreaker(string host, out CircuitBreaker? circuitBreaker) + { + circuitBreaker = null; + + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + if (_circuitBreakers.TryGetValue(host, out var foundCircuitBreaker)) + { + circuitBreaker = foundCircuitBreaker; + return true; + } + + return false; + } + + /// + /// Removes the circuit breaker for the specified host. + /// + /// The host to remove the circuit breaker for. + /// True if the circuit breaker was removed, false if it didn't exist. + internal bool RemoveCircuitBreaker(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + var removed = _circuitBreakers.TryRemove(host, out _); + + if (removed) + { + Debug.WriteLine($"[DEBUG] CircuitBreakerManager: Removed circuit breaker for host '{host}'"); + } + + return removed; + } + + /// + /// Clears all circuit breakers. + /// This is primarily for testing purposes. + /// + internal void Clear() + { + _circuitBreakers.Clear(); + Debug.WriteLine("[DEBUG] CircuitBreakerManager: Cleared all circuit breakers"); + } + } +} diff --git a/csharp/src/Telemetry/CircuitBreakerTelemetryExporter.cs b/csharp/src/Telemetry/CircuitBreakerTelemetryExporter.cs new file mode 100644 index 00000000..8d26916a --- /dev/null +++ b/csharp/src/Telemetry/CircuitBreakerTelemetryExporter.cs @@ -0,0 +1,143 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Wraps a telemetry exporter with circuit breaker protection. + /// + /// + /// This exporter implements the circuit breaker pattern to protect against + /// failing telemetry endpoints: + /// - When circuit is closed: Exports events via the inner exporter + /// - When circuit is open: Drops events silently (logged at DEBUG level) + /// - Circuit breaker MUST see exceptions before they are swallowed + /// + /// The circuit breaker is per-host, managed by CircuitBreakerManager. 
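A hedged composition sketch for the decorator described above; `httpClient`, `config` (a `TelemetryConfiguration`), `host`, and `frontendLogs` are assumed to come from the surrounding driver code:

```csharp
// Illustrative only: the circuit breaker decorator wraps the real exporter.
ITelemetryExporter exporter = new CircuitBreakerTelemetryExporter(
    host,
    new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config));

// Failures are seen by the host's circuit breaker first, then swallowed;
// only cancellation propagates to the caller.
await exporter.ExportAsync(frontendLogs);
```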
+ /// + /// JDBC Reference: CircuitBreakerTelemetryPushClient.java:15 + /// + internal sealed class CircuitBreakerTelemetryExporter : ITelemetryExporter + { + private readonly string _host; + private readonly ITelemetryExporter _innerExporter; + private readonly CircuitBreakerManager _circuitBreakerManager; + + /// + /// Gets the host URL for this exporter. + /// + internal string Host => _host; + + /// + /// Gets the inner telemetry exporter. + /// + internal ITelemetryExporter InnerExporter => _innerExporter; + + /// + /// Creates a new CircuitBreakerTelemetryExporter. + /// + /// The Databricks host URL. + /// The inner telemetry exporter to wrap. + /// Thrown when host is null or whitespace. + /// Thrown when innerExporter is null. + public CircuitBreakerTelemetryExporter(string host, ITelemetryExporter innerExporter) + : this(host, innerExporter, CircuitBreakerManager.GetInstance()) + { + } + + /// + /// Creates a new CircuitBreakerTelemetryExporter with a specified CircuitBreakerManager. + /// + /// The Databricks host URL. + /// The inner telemetry exporter to wrap. + /// The circuit breaker manager to use. + /// Thrown when host is null or whitespace. + /// Thrown when innerExporter or circuitBreakerManager is null. + /// + /// This constructor is primarily for testing to allow injecting a mock CircuitBreakerManager. + /// + internal CircuitBreakerTelemetryExporter( + string host, + ITelemetryExporter innerExporter, + CircuitBreakerManager circuitBreakerManager) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + _host = host; + _innerExporter = innerExporter ?? throw new ArgumentNullException(nameof(innerExporter)); + _circuitBreakerManager = circuitBreakerManager ?? throw new ArgumentNullException(nameof(circuitBreakerManager)); + } + + /// + /// Export telemetry frontend logs to the backend service with circuit breaker protection. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// A task representing the asynchronous export operation. + /// + /// This method never throws exceptions (except for cancellation). All errors are: + /// 1. First seen by the circuit breaker (to track failures) + /// 2. Then swallowed and logged at TRACE level + /// + /// When the circuit is open, events are dropped silently and logged at DEBUG level. 
+ /// + public async Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + if (logs == null || logs.Count == 0) + { + return; + } + + var circuitBreaker = _circuitBreakerManager.GetCircuitBreaker(_host); + + try + { + // Execute through circuit breaker - it tracks failures BEFORE swallowing + await circuitBreaker.ExecuteAsync(async () => + { + await _innerExporter.ExportAsync(logs, ct).ConfigureAwait(false); + }).ConfigureAwait(false); + } + catch (CircuitBreakerOpenException) + { + // Circuit is open - drop events silently + // Log at DEBUG level per design doc (circuit breaker state changes) + Debug.WriteLine($"[DEBUG] CircuitBreakerTelemetryExporter: Circuit breaker OPEN for host '{_host}' - dropping {logs.Count} telemetry events"); + } + catch (OperationCanceledException) + { + // Don't swallow cancellation - let it propagate + throw; + } + catch (Exception ex) + { + // All other exceptions swallowed AFTER circuit breaker saw them + // Log at TRACE level to avoid customer anxiety per design doc + Debug.WriteLine($"[TRACE] CircuitBreakerTelemetryExporter: Error exporting telemetry for host '{_host}': {ex.Message}"); + } + } + } +} diff --git a/csharp/src/Telemetry/DatabricksActivityListener.cs b/csharp/src/Telemetry/DatabricksActivityListener.cs new file mode 100644 index 00000000..96b36e91 --- /dev/null +++ b/csharp/src/Telemetry/DatabricksActivityListener.cs @@ -0,0 +1,337 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Custom ActivityListener that aggregates metrics from Activity events + /// and exports them to Databricks telemetry service. + /// + /// + /// This class: + /// - Listens to activities from the "Databricks.Adbc.Driver" ActivitySource + /// - Uses Sample callback to respect feature flag (enabled via TelemetryConfiguration) + /// - Delegates metric processing to MetricsAggregator on ActivityStopped + /// - All callbacks wrapped in try-catch with TRACE logging to prevent impacting driver operations + /// - All exceptions are swallowed to prevent impacting driver operations + /// + /// JDBC Reference: DatabricksDriverTelemetryHelper.java - sets up telemetry collection + /// + internal sealed class DatabricksActivityListener : IDisposable + { + /// + /// The ActivitySource name that this listener listens to. + /// + public const string DatabricksActivitySourceName = "Databricks.Adbc.Driver"; + + private readonly MetricsAggregator _aggregator; + private readonly TelemetryConfiguration _config; + private readonly Func? _featureFlagChecker; + private ActivityListener? _listener; + private bool _started; + private bool _disposed; + private readonly object _lock = new object(); + + /// + /// Creates a new DatabricksActivityListener. + /// + /// The MetricsAggregator to delegate activity processing to. + /// The telemetry configuration. 
+ /// + /// Optional function to check if telemetry is enabled at runtime. + /// If provided, this is called on each activity sample. If null, uses config.Enabled. + /// + /// Thrown when aggregator or config is null. + public DatabricksActivityListener( + MetricsAggregator aggregator, + TelemetryConfiguration config, + Func? featureFlagChecker = null) + { + _aggregator = aggregator ?? throw new ArgumentNullException(nameof(aggregator)); + _config = config ?? throw new ArgumentNullException(nameof(config)); + _featureFlagChecker = featureFlagChecker; + } + + /// + /// Gets whether the listener has been started. + /// + public bool IsStarted => _started; + + /// + /// Gets whether the listener has been disposed. + /// + public bool IsDisposed => _disposed; + + /// + /// Starts listening to activities from the Databricks ActivitySource. + /// + /// + /// This method creates and registers an ActivityListener with the following configuration: + /// - ShouldListenTo returns true only for "Databricks.Adbc.Driver" ActivitySource + /// - Sample returns AllDataAndRecorded when telemetry is enabled, None when disabled + /// - ActivityStopped delegates to MetricsAggregator.ProcessActivity + /// + /// This method is thread-safe and idempotent (calling multiple times has no additional effect). + /// All exceptions are caught and logged at TRACE level. + /// + public void Start() + { + lock (_lock) + { + if (_started || _disposed) + { + return; + } + + try + { + _listener = CreateListener(); + ActivitySource.AddActivityListener(_listener); + _started = true; + + Debug.WriteLine("[TRACE] DatabricksActivityListener: Started listening to activities"); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error starting listener: {ex.Message}"); + } + } + } + + /// + /// Stops listening to activities and flushes any pending metrics. + /// + /// Cancellation token. + /// A task representing the asynchronous stop operation. + /// + /// This method: + /// 1. Disposes the ActivityListener to stop receiving new activities + /// 2. Flushes all pending metrics via MetricsAggregator.FlushAsync + /// 3. Disposes the MetricsAggregator + /// + /// This method is thread-safe. All exceptions are caught and logged at TRACE level. + /// + public async Task StopAsync(CancellationToken ct = default) + { + lock (_lock) + { + if (!_started || _disposed) + { + return; + } + + try + { + // Stop receiving new activities + _listener?.Dispose(); + _listener = null; + _started = false; + + Debug.WriteLine("[TRACE] DatabricksActivityListener: Stopped listening to activities"); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error stopping listener: {ex.Message}"); + } + } + + // Flush pending metrics outside the lock to avoid deadlocks + try + { + await _aggregator.FlushAsync(ct).ConfigureAwait(false); + Debug.WriteLine("[TRACE] DatabricksActivityListener: Flushed pending metrics"); + } + catch (OperationCanceledException) + { + // Don't swallow cancellation + throw; + } + catch (Exception ex) + { + // Swallow all other exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error flushing metrics: {ex.Message}"); + } + } + + /// + /// Disposes the DatabricksActivityListener. + /// + /// + /// This method stops listening and disposes resources synchronously. + /// Use StopAsync for graceful shutdown with flush. 
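The listener builds on the standard `System.Diagnostics.ActivityListener` mechanism; a self-contained sketch of that mechanism, independent of the driver types, looks like this:

```csharp
using System;
using System.Diagnostics;

internal static class ActivityListenerDemo
{
    public static void Main()
    {
        using var source = new ActivitySource("Databricks.Adbc.Driver");

        using var listener = new ActivityListener
        {
            // Only listen to the driver's ActivitySource.
            ShouldListenTo = s => s.Name == "Databricks.Adbc.Driver",
            // Equivalent to the "telemetry enabled" path: record everything.
            Sample = (ref ActivityCreationOptions<ActivityContext> options) =>
                ActivitySamplingResult.AllDataAndRecorded,
            // Fires when an activity is stopped/disposed.
            ActivityStopped = a =>
                Console.WriteLine($"stopped: {a.OperationName} after {a.Duration.TotalMilliseconds} ms")
        };
        ActivitySource.AddActivityListener(listener);

        using (var activity = source.StartActivity("Statement.Execute"))
        {
            activity?.SetTag("statement_id", "example-statement-id");
        }
        // ActivityStopped runs when the activity above is disposed.
    }
}
```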
+ /// All exceptions are caught and logged at TRACE level. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + lock (_lock) + { + if (_disposed) + { + return; + } + + _disposed = true; + + try + { + _listener?.Dispose(); + _listener = null; + _started = false; + + // Dispose aggregator (this also flushes synchronously) + _aggregator.Dispose(); + + Debug.WriteLine("[TRACE] DatabricksActivityListener: Disposed"); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error during dispose: {ex.Message}"); + } + } + } + + #region Private Methods + + /// + /// Creates the ActivityListener with the appropriate callbacks. + /// + private ActivityListener CreateListener() + { + return new ActivityListener + { + ShouldListenTo = ShouldListenTo, + Sample = Sample, + ActivityStarted = OnActivityStarted, + ActivityStopped = OnActivityStopped + }; + } + + /// + /// Determines if the listener should listen to the given ActivitySource. + /// + /// The ActivitySource to check. + /// True if the source is "Databricks.Adbc.Driver", false otherwise. + private bool ShouldListenTo(ActivitySource source) + { + try + { + return string.Equals(source.Name, DatabricksActivitySourceName, StringComparison.Ordinal); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error in ShouldListenTo: {ex.Message}"); + return false; + } + } + + /// + /// Determines the sampling result for activity creation. + /// + /// The activity creation options. + /// + /// AllDataAndRecorded when telemetry is enabled, None when disabled. + /// + private ActivitySamplingResult Sample(ref ActivityCreationOptions options) + { + try + { + // Check feature flag if provided, otherwise use config + bool enabled = _featureFlagChecker?.Invoke() ?? _config.Enabled; + + return enabled + ? ActivitySamplingResult.AllDataAndRecorded + : ActivitySamplingResult.None; + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + // On error, return None (don't sample) as a safe default + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error in Sample: {ex.Message}"); + return ActivitySamplingResult.None; + } + } + + /// + /// Called when an activity starts. + /// + /// The started activity. + /// + /// Currently a no-op. The listener primarily processes activities on stop. + /// All exceptions are caught and logged at TRACE level. + /// + private void OnActivityStarted(Activity activity) + { + try + { + // Currently no processing needed on start + // The listener primarily processes activities when they stop + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error in OnActivityStarted: {ex.Message}"); + } + } + + /// + /// Called when an activity stops. + /// + /// The stopped activity. + /// + /// Delegates to MetricsAggregator.ProcessActivity to extract and aggregate metrics. + /// Only processes activities when telemetry is enabled (checked via feature flag or config). + /// All exceptions are caught and logged at TRACE level. + /// + private void OnActivityStopped(Activity activity) + { + try + { + // Check if telemetry is enabled before processing + // This is needed because ActivityStopped is called even if Sample returned None + // (when another listener requested the activity) + bool enabled = _featureFlagChecker?.Invoke() ?? 
_config.Enabled; + if (!enabled) + { + return; + } + + _aggregator.ProcessActivity(activity); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + // Use TRACE level to avoid customer anxiety + Debug.WriteLine($"[TRACE] DatabricksActivityListener: Error in OnActivityStopped: {ex.Message}"); + } + } + + #endregion + } +} diff --git a/csharp/src/Telemetry/DatabricksTelemetryExporter.cs b/csharp/src/Telemetry/DatabricksTelemetryExporter.cs new file mode 100644 index 00000000..e0c8d6dd --- /dev/null +++ b/csharp/src/Telemetry/DatabricksTelemetryExporter.cs @@ -0,0 +1,285 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Exports telemetry events to the Databricks telemetry service. + /// + /// + /// This exporter: + /// - Creates TelemetryRequest wrapper with uploadTime and protoLogs + /// - Uses /telemetry-ext for authenticated requests + /// - Uses /telemetry-unauth for unauthenticated requests + /// - Implements retry logic for transient failures + /// - Never throws exceptions (all swallowed and traced at Verbose level) + /// + /// JDBC Reference: TelemetryPushClient.java + /// + internal sealed class DatabricksTelemetryExporter : ITelemetryExporter + { + /// + /// Authenticated telemetry endpoint path. + /// + internal const string AuthenticatedEndpoint = "/telemetry-ext"; + + /// + /// Unauthenticated telemetry endpoint path. + /// + internal const string UnauthenticatedEndpoint = "/telemetry-unauth"; + + /// + /// Activity source for telemetry exporter tracing. + /// + private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.TelemetryExporter"); + + private readonly HttpClient _httpClient; + private readonly string _host; + private readonly bool _isAuthenticated; + private readonly TelemetryConfiguration _config; + + private static readonly JsonSerializerOptions s_jsonOptions = new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + + /// + /// Gets the host URL for the telemetry endpoint. + /// + internal string Host => _host; + + /// + /// Gets whether this exporter uses authenticated endpoints. + /// + internal bool IsAuthenticated => _isAuthenticated; + + /// + /// Creates a new DatabricksTelemetryExporter. + /// + /// The HTTP client to use for sending requests. + /// The Databricks host URL. + /// Whether to use authenticated endpoints. + /// The telemetry configuration. + /// Thrown when httpClient, host, or config is null. + /// Thrown when host is empty or whitespace. 
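For reference, a sketch of the wire shape the exporter produces; it assumes `TelemetryRequest.ProtoLogs` is a `List<string>` and that the serialized property names come out camelCased as configured above (values are made up):

```csharp
// Illustrative only: mirrors the CreateTelemetryRequest + SerializeRequest pair below.
var request = new TelemetryRequest
{
    UploadTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
    ProtoLogs = new List<string> { "{ /* JSON-serialized TelemetryFrontendLog */ }" }
};

string json = JsonSerializer.Serialize(request, new JsonSerializerOptions
{
    PropertyNamingPolicy = JsonNamingPolicy.CamelCase
});
// Roughly: {"uploadTime":1739999999999,"protoLogs":["..."]}
```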
+ public DatabricksTelemetryExporter( + HttpClient httpClient, + string host, + bool isAuthenticated, + TelemetryConfiguration config) + { + _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient)); + + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + _host = host; + _isAuthenticated = isAuthenticated; + _config = config ?? throw new ArgumentNullException(nameof(config)); + } + + /// + /// Export telemetry frontend logs to the Databricks telemetry service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// + /// True if the export succeeded (HTTP 2xx response), false if it failed. + /// Returns true for empty/null logs since there's nothing to export. + /// + /// + /// This method never throws exceptions. All errors are caught and traced using ActivitySource. + /// + public async Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + if (logs == null || logs.Count == 0) + { + return true; + } + + try + { + var request = CreateTelemetryRequest(logs); + var json = SerializeRequest(request); + + return await SendWithRetryAsync(json, ct).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Don't swallow cancellation - let it propagate + throw; + } + catch (Exception ex) + { + // Swallow all other exceptions per telemetry requirement + // Trace at Verbose level to avoid customer anxiety + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.error", + tags: new ActivityTagsCollection + { + { "error.message", ex.Message }, + { "error.type", ex.GetType().Name } + })); + return false; + } + } + + /// + /// Creates a TelemetryRequest from a list of frontend logs. + /// + internal TelemetryRequest CreateTelemetryRequest(IReadOnlyList logs) + { + var protoLogs = new List(logs.Count); + + foreach (var log in logs) + { + var serializedLog = JsonSerializer.Serialize(log, s_jsonOptions); + protoLogs.Add(serializedLog); + } + + return new TelemetryRequest + { + UploadTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), + ProtoLogs = protoLogs + }; + } + + /// + /// Serializes the telemetry request to JSON. + /// + internal string SerializeRequest(TelemetryRequest request) + { + return JsonSerializer.Serialize(request, s_jsonOptions); + } + + /// + /// Gets the telemetry endpoint URL based on authentication status. + /// + internal string GetEndpointUrl() + { + var endpoint = _isAuthenticated ? AuthenticatedEndpoint : UnauthenticatedEndpoint; + var host = _host.TrimEnd('/'); + return $"{host}{endpoint}"; + } + + /// + /// Sends the telemetry request with retry logic. + /// + /// True if the request succeeded, false otherwise. + private async Task SendWithRetryAsync(string json, CancellationToken ct) + { + var endpointUrl = GetEndpointUrl(); + Exception? 
lastException = null; + + for (int attempt = 0; attempt <= _config.MaxRetries; attempt++) + { + try + { + if (attempt > 0 && _config.RetryDelayMs > 0) + { + await Task.Delay(_config.RetryDelayMs, ct).ConfigureAwait(false); + } + + await SendRequestAsync(endpointUrl, json, ct).ConfigureAwait(false); + + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.success", + tags: new ActivityTagsCollection + { + { "endpoint", endpointUrl }, + { "attempt", attempt + 1 } + })); + return true; + } + catch (OperationCanceledException) + { + // Don't retry on cancellation + throw; + } + catch (HttpRequestException ex) + { + lastException = ex; + + // Check if this is a terminal error that shouldn't be retried + if (ExceptionClassifier.IsTerminalException(ex)) + { + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.terminal_error", + tags: new ActivityTagsCollection + { + { "error.message", ex.Message }, + { "error.type", ex.GetType().Name } + })); + return false; + } + + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.retry", + tags: new ActivityTagsCollection + { + { "attempt", attempt + 1 }, + { "max_attempts", _config.MaxRetries + 1 }, + { "error.message", ex.Message } + })); + } + catch (Exception ex) + { + lastException = ex; + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.retry", + tags: new ActivityTagsCollection + { + { "attempt", attempt + 1 }, + { "max_attempts", _config.MaxRetries + 1 }, + { "error.message", ex.Message }, + { "error.type", ex.GetType().Name } + })); + } + } + + if (lastException != null) + { + Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.exhausted", + tags: new ActivityTagsCollection + { + { "total_attempts", _config.MaxRetries + 1 }, + { "error.message", lastException.Message }, + { "error.type", lastException.GetType().Name } + })); + } + + return false; + } + + /// + /// Sends the HTTP request to the telemetry endpoint. + /// + private async Task SendRequestAsync(string endpointUrl, string json, CancellationToken ct) + { + using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var response = await _httpClient.PostAsync(endpointUrl, content, ct).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + } + } +} diff --git a/csharp/src/Telemetry/ITelemetryClient.cs b/csharp/src/Telemetry/ITelemetryClient.cs new file mode 100644 index 00000000..d877239d --- /dev/null +++ b/csharp/src/Telemetry/ITelemetryClient.cs @@ -0,0 +1,74 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Interface for a telemetry client that exports telemetry events to a backend service. + /// + /// + /// This interface represents a per-host telemetry client that is shared across + /// multiple connections to the same Databricks workspace. 
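The reference-counted, per-host lifetime described in the surrounding remarks can be pictured with the following hypothetical sketch (the actual `TelemetryClientManager` in this change may be structured differently):

```csharp
// Hypothetical shape only: one shared client per host, closed when the last
// connection to that host releases it.
internal sealed class PerHostClientCache
{
    private readonly object _lock = new object();
    private readonly Dictionary<string, (ITelemetryClient Client, int RefCount)> _clients =
        new(StringComparer.OrdinalIgnoreCase);

    public ITelemetryClient Acquire(string host, Func<string, ITelemetryClient> factory)
    {
        lock (_lock)
        {
            if (_clients.TryGetValue(host, out var entry))
            {
                _clients[host] = (entry.Client, entry.RefCount + 1);
                return entry.Client;
            }

            var client = factory(host);
            _clients[host] = (client, 1);
            return client;
        }
    }

    public async Task ReleaseAsync(string host)
    {
        ITelemetryClient? toClose = null;
        lock (_lock)
        {
            if (_clients.TryGetValue(host, out var entry))
            {
                if (entry.RefCount <= 1)
                {
                    _clients.Remove(host);
                    toClose = entry.Client;   // last connection: close the shared client
                }
                else
                {
                    _clients[host] = (entry.Client, entry.RefCount - 1);
                }
            }
        }

        if (toClose != null)
        {
            await toClose.CloseAsync().ConfigureAwait(false);  // flushes pending events
        }
    }
}
```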
The client is managed + /// by with reference counting. + /// + /// Implementations must: + /// - Be thread-safe for concurrent access from multiple connections + /// - Never throw exceptions from ExportAsync (all caught and logged at TRACE level) + /// - Support graceful shutdown via CloseAsync + /// + /// JDBC Reference: TelemetryClient.java + /// + public interface ITelemetryClient + { + /// + /// Gets the host URL for this telemetry client. + /// + string Host { get; } + + /// + /// Exports telemetry frontend logs to the backend service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// A task representing the asynchronous export operation. + /// + /// This method must never throw exceptions (except for cancellation). + /// All errors should be caught and logged at TRACE level internally. + /// + Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default); + + /// + /// Closes the telemetry client and releases all resources. + /// + /// A task representing the asynchronous close operation. + /// + /// This method is called by TelemetryClientManager when the last connection + /// using this client is closed. The implementation should: + /// 1. Flush any pending telemetry events + /// 2. Cancel any background tasks + /// 3. Dispose all resources + /// + /// This method must never throw exceptions. All errors should be caught + /// and logged at TRACE level internally. + /// + Task CloseAsync(); + } +} diff --git a/csharp/src/Telemetry/ITelemetryExporter.cs b/csharp/src/Telemetry/ITelemetryExporter.cs new file mode 100644 index 00000000..934a5bba --- /dev/null +++ b/csharp/src/Telemetry/ITelemetryExporter.cs @@ -0,0 +1,53 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Interface for exporting telemetry events to a backend service. + /// + /// + /// Implementations of this interface must be safe to call from any context. + /// All methods should be non-blocking and should never throw exceptions + /// (exceptions should be caught and logged at TRACE level internally). + /// This follows the telemetry design principle that telemetry operations + /// should never impact driver operations. + /// + public interface ITelemetryExporter + { + /// + /// Export telemetry frontend logs to the backend service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// + /// A task that resolves to true if the export succeeded (HTTP 2xx response), + /// or false if the export failed or was skipped. Returns true for empty/null logs + /// since there's nothing to export (no failure occurred). + /// + /// + /// This method must never throw exceptions. All errors should be caught + /// and logged at TRACE level internally. 
The method may return early + /// if the circuit breaker is open or if there are no logs to export. + /// + Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default); + } +} diff --git a/csharp/src/Telemetry/MetricsAggregator.cs b/csharp/src/Telemetry/MetricsAggregator.cs new file mode 100644 index 00000000..29d8e3df --- /dev/null +++ b/csharp/src/Telemetry/MetricsAggregator.cs @@ -0,0 +1,737 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; +using AdbcDrivers.Databricks.Telemetry.TagDefinitions; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Aggregates metrics from activities by statement_id and includes session_id. + /// Follows JDBC driver pattern: aggregation by statement, export with both IDs. + /// + /// + /// This class: + /// - Aggregates Activity data by statement_id using ConcurrentDictionary + /// - Connection events emit immediately (no aggregation needed) + /// - Statement events aggregate until CompleteStatement is called + /// - Terminal exceptions flush immediately, retryable exceptions buffer until complete + /// - Uses TelemetryTagRegistry to filter tags + /// - Creates TelemetryFrontendLog wrapper with workspace_id + /// - All exceptions swallowed (logged at TRACE level) + /// + /// JDBC Reference: TelemetryCollector.java + /// + internal sealed class MetricsAggregator : IDisposable + { + private readonly ITelemetryExporter _exporter; + private readonly TelemetryConfiguration _config; + private readonly long _workspaceId; + private readonly string _userAgent; + private readonly ConcurrentDictionary _statementContexts; + private readonly ConcurrentQueue _pendingEvents; + private readonly object _flushLock = new object(); + private readonly Timer _flushTimer; + private bool _disposed; + + /// + /// Gets the number of pending events waiting to be flushed. + /// + internal int PendingEventCount => _pendingEvents.Count; + + /// + /// Gets the number of active statement contexts being tracked. + /// + internal int ActiveStatementCount => _statementContexts.Count; + + /// + /// Creates a new MetricsAggregator. + /// + /// The telemetry exporter to use for flushing events. + /// The telemetry configuration. + /// The Databricks workspace ID for all events. + /// The user agent string for client context. + /// Thrown when exporter or config is null. + public MetricsAggregator( + ITelemetryExporter exporter, + TelemetryConfiguration config, + long workspaceId, + string? userAgent = null) + { + _exporter = exporter ?? throw new ArgumentNullException(nameof(exporter)); + _config = config ?? throw new ArgumentNullException(nameof(config)); + _workspaceId = workspaceId; + _userAgent = userAgent ?? 
"AdbcDatabricksDriver"; + _statementContexts = new ConcurrentDictionary(); + _pendingEvents = new ConcurrentQueue(); + + // Start flush timer + _flushTimer = new Timer( + OnFlushTimerTick, + null, + _config.FlushIntervalMs, + _config.FlushIntervalMs); + } + + /// + /// Processes a completed activity and extracts telemetry metrics. + /// + /// The activity to process. + /// + /// This method determines the event type based on the activity operation name: + /// - Connection.* activities emit connection events immediately + /// - Statement.* activities are aggregated by statement_id + /// - Activities with error.type tag are treated as error events + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void ProcessActivity(Activity? activity) + { + if (activity == null) + { + return; + } + + try + { + var eventType = DetermineEventType(activity); + + switch (eventType) + { + case TelemetryEventType.ConnectionOpen: + ProcessConnectionActivity(activity); + break; + + case TelemetryEventType.StatementExecution: + ProcessStatementActivity(activity); + break; + + case TelemetryEventType.Error: + ProcessErrorActivity(activity); + break; + } + + // Check if we should flush based on batch size + if (_pendingEvents.Count >= _config.BatchSize) + { + _ = FlushAsync(); + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + // Log at TRACE level to avoid customer anxiety + Debug.WriteLine($"[TRACE] MetricsAggregator: Error processing activity: {ex.Message}"); + } + } + + /// + /// Marks a statement as complete and emits the aggregated telemetry event. + /// + /// The statement ID to complete. + /// Whether the statement execution failed. + /// + /// This method: + /// - Removes the statement context from the aggregation dictionary + /// - Emits the aggregated event with all collected metrics + /// - If failed, also emits any buffered retryable exception events + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void CompleteStatement(string statementId, bool failed = false) + { + if (string.IsNullOrEmpty(statementId)) + { + return; + } + + try + { + if (_statementContexts.TryRemove(statementId, out var context)) + { + // Emit the aggregated statement event + var telemetryEvent = CreateTelemetryEvent(context); + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + // If statement failed and we have buffered exceptions, emit them + if (failed && context.BufferedExceptions.Count > 0) + { + foreach (var exception in context.BufferedExceptions) + { + var errorEvent = CreateErrorTelemetryEvent( + context.SessionId, + statementId, + exception); + var errorLog = WrapInFrontendLog(errorEvent); + _pendingEvents.Enqueue(errorLog); + } + } + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] MetricsAggregator: Error completing statement: {ex.Message}"); + } + } + + /// + /// Records an exception for a statement. + /// + /// The statement ID. + /// The session ID. + /// The exception to record. + /// + /// Terminal exceptions are flushed immediately. + /// Retryable exceptions are buffered until CompleteStatement is called. + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void RecordException(string statementId, string sessionId, Exception? 
exception) + { + if (exception == null || string.IsNullOrEmpty(statementId)) + { + return; + } + + try + { + if (ExceptionClassifier.IsTerminalException(exception)) + { + // Terminal exception: flush immediately + var errorEvent = CreateErrorTelemetryEvent(sessionId, statementId, exception); + var errorLog = WrapInFrontendLog(errorEvent); + _pendingEvents.Enqueue(errorLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Terminal exception recorded, flushing immediately"); + } + else + { + // Retryable exception: buffer until statement completes + var context = _statementContexts.GetOrAdd( + statementId, + _ => new StatementTelemetryContext(statementId, sessionId)); + context.BufferedExceptions.Add(exception); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Retryable exception buffered for statement {statementId}"); + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] MetricsAggregator: Error recording exception: {ex.Message}"); + } + } + + /// + /// Flushes all pending telemetry events to the exporter. + /// + /// Cancellation token. + /// A task representing the flush operation. + /// + /// This method is thread-safe and uses a lock to prevent concurrent flushes. + /// All exceptions are swallowed and logged at TRACE level. + /// + public async Task FlushAsync(CancellationToken ct = default) + { + try + { + List eventsToFlush; + + lock (_flushLock) + { + if (_pendingEvents.IsEmpty) + { + return; + } + + eventsToFlush = new List(); + while (_pendingEvents.TryDequeue(out var eventLog)) + { + eventsToFlush.Add(eventLog); + } + } + + if (eventsToFlush.Count > 0) + { + Debug.WriteLine($"[TRACE] MetricsAggregator: Flushing {eventsToFlush.Count} events"); + await _exporter.ExportAsync(eventsToFlush, ct).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + // Don't swallow cancellation + throw; + } + catch (Exception ex) + { + // Swallow all other exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] MetricsAggregator: Error flushing events: {ex.Message}"); + } + } + + /// + /// Disposes the MetricsAggregator and flushes any remaining events. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + try + { + _flushTimer.Dispose(); + + // Final flush + FlushAsync().GetAwaiter().GetResult(); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] MetricsAggregator: Error during dispose: {ex.Message}"); + } + } + + #region Private Methods + + /// + /// Timer callback for periodic flushing. + /// + private void OnFlushTimerTick(object? state) + { + if (_disposed) + { + return; + } + + try + { + _ = FlushAsync(); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] MetricsAggregator: Error in flush timer: {ex.Message}"); + } + } + + /// + /// Determines the telemetry event type based on the activity. + /// + private static TelemetryEventType DetermineEventType(Activity activity) + { + // Check for errors first + if (activity.GetTagItem("error.type") != null) + { + return TelemetryEventType.Error; + } + + // Map based on operation name + var operationName = activity.OperationName ?? 
string.Empty; + + if (operationName.StartsWith("Connection.", StringComparison.OrdinalIgnoreCase) || + operationName.Equals("OpenConnection", StringComparison.OrdinalIgnoreCase) || + operationName.Equals("OpenAsync", StringComparison.OrdinalIgnoreCase)) + { + return TelemetryEventType.ConnectionOpen; + } + + // Default to statement execution for statement operations and others + return TelemetryEventType.StatementExecution; + } + + /// + /// Processes a connection activity and emits it immediately. + /// + private void ProcessConnectionActivity(Activity activity) + { + var sessionId = GetTagValue(activity, "session.id"); + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + SystemConfiguration = ExtractSystemConfiguration(activity), + ConnectionParameters = ExtractConnectionParameters(activity) + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Connection event emitted immediately"); + } + + /// + /// Processes a statement activity and aggregates it by statement_id. + /// + private void ProcessStatementActivity(Activity activity) + { + var statementId = GetTagValue(activity, "statement.id"); + var sessionId = GetTagValue(activity, "session.id"); + + if (string.IsNullOrEmpty(statementId)) + { + // No statement ID, cannot aggregate - emit immediately + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + SqlExecutionEvent = ExtractSqlExecutionEvent(activity) + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + return; + } + + // Get or create context for this statement + var context = _statementContexts.GetOrAdd( + statementId, + _ => new StatementTelemetryContext(statementId, sessionId)); + + // Aggregate metrics + context.AddLatency((long)activity.Duration.TotalMilliseconds); + AggregateActivityTags(context, activity); + } + + /// + /// Processes an error activity and emits it immediately. + /// + private void ProcessErrorActivity(Activity activity) + { + var statementId = GetTagValue(activity, "statement.id"); + var sessionId = GetTagValue(activity, "session.id"); + var errorType = GetTagValue(activity, "error.type"); + var errorMessage = GetTagValue(activity, "error.message"); + var errorCode = GetTagValue(activity, "error.code"); + + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + SqlStatementId = statementId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + ErrorInfo = new DriverErrorInfo + { + ErrorType = errorType, + ErrorMessage = errorMessage, + ErrorCode = errorCode + } + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Error event emitted immediately"); + } + + /// + /// Aggregates activity tags into the statement context. 
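The aggregator only sees what the driver's activities carry as tags, so an emitting call site has to use the same tag names (`session.id`, `statement.id`, `result.*`, `poll.*`, `error.*`). Below is a hedged sketch of such a call site; the ActivitySource name, operation name, and tag values are assumptions rather than driver constants.

```csharp
using System.Diagnostics;

// Illustrative emitter; the real driver creates these activities internally and the
// ActivitySource name here is an assumption, not taken from the driver code.
internal static class StatementActivityEmitterSketch
{
    private static readonly ActivitySource s_source = new ActivitySource("AdbcDrivers.Databricks");

    public static void EmitStatementActivity(string sessionId, string statementId)
    {
        // StartActivity returns null when no listener (e.g. DatabricksActivityListener) is sampling.
        using Activity? activity = s_source.StartActivity("Statement.ExecuteQuery");
        activity?.SetTag("session.id", sessionId);
        activity?.SetTag("statement.id", statementId);

        // ... execute the statement ...

        // Tag names below are the ones MetricsAggregator aggregates per statement;
        // the values are placeholders.
        activity?.SetTag("result.format", "ARROW_BASED_SET");
        activity?.SetTag("result.chunk_count", "4");
        activity?.SetTag("result.bytes_downloaded", "1048576");
        activity?.SetTag("execution.status", "SUCCEEDED");
        // On error the driver would set "error.type"/"error.message" instead, which routes
        // the completed activity through ProcessErrorActivity.
    }
}
```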
+ /// + private void AggregateActivityTags(StatementTelemetryContext context, Activity activity) + { + var eventType = TelemetryEventType.StatementExecution; + + foreach (var tag in activity.Tags) + { + // Filter using TelemetryTagRegistry + if (!TelemetryTagRegistry.ShouldExportToDatabricks(eventType, tag.Key)) + { + continue; + } + + switch (tag.Key) + { + case "result.format": + context.ResultFormat = tag.Value; + break; + case "result.chunk_count": + if (int.TryParse(tag.Value, out int chunkCount)) + { + context.ChunkCount = (context.ChunkCount ?? 0) + chunkCount; + } + break; + case "result.bytes_downloaded": + if (long.TryParse(tag.Value, out long bytesDownloaded)) + { + context.BytesDownloaded = (context.BytesDownloaded ?? 0) + bytesDownloaded; + } + break; + case "result.compression_enabled": + if (bool.TryParse(tag.Value, out bool compressionEnabled)) + { + context.CompressionEnabled = compressionEnabled; + } + break; + case "result.row_count": + if (long.TryParse(tag.Value, out long rowCount)) + { + context.RowCount = rowCount; + } + break; + case "poll.count": + if (int.TryParse(tag.Value, out int pollCount)) + { + context.PollCount = (context.PollCount ?? 0) + pollCount; + } + break; + case "poll.latency_ms": + if (long.TryParse(tag.Value, out long pollLatencyMs)) + { + context.PollLatencyMs = (context.PollLatencyMs ?? 0) + pollLatencyMs; + } + break; + case "execution.status": + context.ExecutionStatus = tag.Value; + break; + case "statement.type": + context.StatementType = tag.Value; + break; + } + } + } + + /// + /// Creates a TelemetryEvent from a statement context. + /// + private TelemetryEvent CreateTelemetryEvent(StatementTelemetryContext context) + { + return new TelemetryEvent + { + SessionId = context.SessionId, + SqlStatementId = context.StatementId, + OperationLatencyMs = context.TotalLatencyMs, + SqlExecutionEvent = new SqlExecutionEvent + { + ResultFormat = context.ResultFormat, + ChunkCount = context.ChunkCount, + BytesDownloaded = context.BytesDownloaded, + CompressionEnabled = context.CompressionEnabled, + RowCount = context.RowCount, + PollCount = context.PollCount, + PollLatencyMs = context.PollLatencyMs, + ExecutionStatus = context.ExecutionStatus, + StatementType = context.StatementType + } + }; + } + + /// + /// Creates an error TelemetryEvent from an exception. + /// + private TelemetryEvent CreateErrorTelemetryEvent( + string? sessionId, + string statementId, + Exception exception) + { + var isTerminal = ExceptionClassifier.IsTerminalException(exception); + int? httpStatusCode = null; + +#if NET5_0_OR_GREATER + if (exception is System.Net.Http.HttpRequestException httpEx && httpEx.StatusCode.HasValue) + { + httpStatusCode = (int)httpEx.StatusCode.Value; + } +#endif + + return new TelemetryEvent + { + SessionId = sessionId, + SqlStatementId = statementId, + ErrorInfo = new DriverErrorInfo + { + ErrorType = exception.GetType().Name, + ErrorMessage = SanitizeErrorMessage(exception.Message), + IsTerminal = isTerminal, + HttpStatusCode = httpStatusCode + } + }; + } + + /// + /// Sanitizes an error message to remove potential PII. + /// + private static string SanitizeErrorMessage(string? message) + { + if (string.IsNullOrEmpty(message)) + { + return string.Empty; + } + + // Truncate long messages + const int maxLength = 200; + if (message.Length > maxLength) + { + message = message.Substring(0, maxLength) + "..."; + } + + return message; + } + + /// + /// Wraps a TelemetryEvent in a TelemetryFrontendLog. 
+ /// + internal TelemetryFrontendLog WrapInFrontendLog(TelemetryEvent telemetryEvent) + { + return new TelemetryFrontendLog + { + WorkspaceId = _workspaceId, + FrontendLogEventId = Guid.NewGuid().ToString(), + Context = new FrontendLogContext + { + ClientContext = new TelemetryClientContext + { + UserAgent = _userAgent + }, + TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() + }, + Entry = new FrontendLogEntry + { + SqlDriverLog = telemetryEvent + } + }; + } + + /// + /// Extracts system configuration from activity tags. + /// + private static DriverSystemConfiguration? ExtractSystemConfiguration(Activity activity) + { + var driverVersion = GetTagValue(activity, "driver.version"); + var osName = GetTagValue(activity, "driver.os"); + var runtime = GetTagValue(activity, "driver.runtime"); + + if (driverVersion == null && osName == null && runtime == null) + { + return null; + } + + return new DriverSystemConfiguration + { + DriverName = "Databricks ADBC Driver", + DriverVersion = driverVersion, + OsName = osName, + RuntimeName = ".NET", + RuntimeVersion = runtime + }; + } + + /// + /// Extracts connection parameters from activity tags. + /// + private static DriverConnectionParameters? ExtractConnectionParameters(Activity activity) + { + var cloudFetchEnabled = GetTagValue(activity, "feature.cloudfetch"); + var lz4Enabled = GetTagValue(activity, "feature.lz4"); + var directResultsEnabled = GetTagValue(activity, "feature.direct_results"); + + if (cloudFetchEnabled == null && lz4Enabled == null && directResultsEnabled == null) + { + return null; + } + + return new DriverConnectionParameters + { + CloudFetchEnabled = bool.TryParse(cloudFetchEnabled, out var cf) ? cf : (bool?)null, + Lz4CompressionEnabled = bool.TryParse(lz4Enabled, out var lz4) ? lz4 : (bool?)null, + DirectResultsEnabled = bool.TryParse(directResultsEnabled, out var dr) ? dr : (bool?)null + }; + } + + /// + /// Extracts SQL execution event details from activity tags. + /// + private static SqlExecutionEvent? ExtractSqlExecutionEvent(Activity activity) + { + var resultFormat = GetTagValue(activity, "result.format"); + var chunkCountStr = GetTagValue(activity, "result.chunk_count"); + var bytesDownloadedStr = GetTagValue(activity, "result.bytes_downloaded"); + var pollCountStr = GetTagValue(activity, "poll.count"); + + if (resultFormat == null && chunkCountStr == null && bytesDownloadedStr == null && pollCountStr == null) + { + return null; + } + + return new SqlExecutionEvent + { + ResultFormat = resultFormat, + ChunkCount = int.TryParse(chunkCountStr, out var cc) ? cc : (int?)null, + BytesDownloaded = long.TryParse(bytesDownloadedStr, out var bd) ? bd : (long?)null, + PollCount = int.TryParse(pollCountStr, out var pc) ? pc : (int?)null + }; + } + + /// + /// Gets a tag value from an activity. + /// + private static string? GetTagValue(Activity activity, string tagName) + { + return activity.GetTagItem(tagName)?.ToString(); + } + + #endregion + + #region Nested Classes + + /// + /// Holds aggregated telemetry data for a statement. + /// + internal sealed class StatementTelemetryContext + { + private long _totalLatencyMs; + + public string StatementId { get; } + public string? SessionId { get; set; } + public long TotalLatencyMs => Interlocked.Read(ref _totalLatencyMs); + + // Aggregated metrics + public string? ResultFormat { get; set; } + public int? ChunkCount { get; set; } + public long? BytesDownloaded { get; set; } + public bool? CompressionEnabled { get; set; } + public long? 
RowCount { get; set; } + public int? PollCount { get; set; } + public long? PollLatencyMs { get; set; } + public string? ExecutionStatus { get; set; } + public string? StatementType { get; set; } + + // Buffered exceptions for retryable errors + public ConcurrentBag BufferedExceptions { get; } = new ConcurrentBag(); + + public StatementTelemetryContext(string statementId, string? sessionId) + { + StatementId = statementId ?? throw new ArgumentNullException(nameof(statementId)); + SessionId = sessionId; + } + + public void AddLatency(long latencyMs) + { + Interlocked.Add(ref _totalLatencyMs, latencyMs); + } + } + + #endregion + } +} diff --git a/csharp/src/Telemetry/TelemetryClient.cs b/csharp/src/Telemetry/TelemetryClient.cs new file mode 100644 index 00000000..0d0d52f4 --- /dev/null +++ b/csharp/src/Telemetry/TelemetryClient.cs @@ -0,0 +1,324 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Main telemetry client that coordinates listener, aggregator, and exporter. + /// Manages background flush task and graceful shutdown. + /// + /// + /// This class implements the TelemetryClient from Section 9.3 of the design doc: + /// - Coordinates lifecycle of DatabricksActivityListener, MetricsAggregator, and CircuitBreakerTelemetryExporter + /// - Manages a background flush task for periodic export + /// - Implements graceful shutdown via CloseAsync: cancels background task, flushes pending, disposes resources + /// - Never throws exceptions (all swallowed and logged at TRACE level) + /// + /// The client is managed per-host by TelemetryClientManager with reference counting. + /// + /// JDBC Reference: TelemetryClient.java + /// + internal sealed class TelemetryClient : ITelemetryClient + { + private readonly string _host; + private readonly DatabricksActivityListener _listener; + private readonly MetricsAggregator _aggregator; + private readonly ITelemetryExporter _exporter; + private readonly CancellationTokenSource _cts; + private readonly Task _backgroundFlushTask; + private readonly TelemetryConfiguration _config; + private bool _closed; + private readonly object _closeLock = new object(); + + /// + /// Gets the host URL for this telemetry client. + /// + public string Host => _host; + + /// + /// Creates a new TelemetryClient that coordinates all telemetry components. + /// + /// The Databricks host URL. + /// The HTTP client to use for sending telemetry requests. + /// The telemetry configuration. + /// The Databricks workspace ID. + /// The user agent string for client context. + /// Thrown when host is null or whitespace. + /// Thrown when httpClient or config is null. 
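Putting the aggregator pieces together, the intended flow is: process activities, complete the statement, then flush. A walkthrough sketch using the in-memory exporter from above; it assumes internal access to `MetricsAggregator` (e.g. via InternalsVisibleTo) and that `BatchSize` and `FlushIntervalMs` are settable options like `MaxRetries` and `RetryDelayMs`.

```csharp
using System.Diagnostics;
using System.Threading.Tasks;
using AdbcDrivers.Databricks.Telemetry;

public static class AggregatorWalkthroughSketch
{
    public static async Task RunAsync()
    {
        var exporter = new InMemoryTelemetryExporter();   // illustrative test double from above
        var config = new TelemetryConfiguration { BatchSize = 100, FlushIntervalMs = 60_000 };

        using var aggregator = new MetricsAggregator(exporter, config, workspaceId: 123, userAgent: "test");

        // Simulate a completed statement activity carrying the expected tags.
        var activity = new Activity("Statement.ExecuteQuery");
        activity.SetTag("session.id", "session-1");
        activity.SetTag("statement.id", "stmt-1");
        activity.SetTag("result.row_count", "42");
        activity.Start();
        activity.Stop();

        aggregator.ProcessActivity(activity);    // aggregated under "stmt-1"
        aggregator.CompleteStatement("stmt-1");  // enqueues one TelemetryFrontendLog
        await aggregator.FlushAsync();           // hands the batch to the exporter

        // exporter.Exported now holds the wrapped frontend log for "stmt-1".
    }
}
```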
+ public TelemetryClient( + string host, + HttpClient httpClient, + TelemetryConfiguration config, + long workspaceId = 0, + string? userAgent = null) + : this(host, httpClient, config, workspaceId, userAgent, null, null, null) + { + } + + /// + /// Creates a new TelemetryClient with optional component injection for testing. + /// + /// The Databricks host URL. + /// The HTTP client to use for sending telemetry requests. + /// The telemetry configuration. + /// The Databricks workspace ID. + /// The user agent string for client context. + /// Optional custom exporter (for testing). + /// Optional custom aggregator (for testing). + /// Optional custom listener (for testing). + internal TelemetryClient( + string host, + HttpClient httpClient, + TelemetryConfiguration config, + long workspaceId, + string? userAgent, + ITelemetryExporter? exporter, + MetricsAggregator? aggregator, + DatabricksActivityListener? listener) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + _host = host; + _config = config; + _cts = new CancellationTokenSource(); + + // Initialize exporter: CircuitBreakerTelemetryExporter wrapping DatabricksTelemetryExporter + if (exporter != null) + { + _exporter = exporter; + } + else + { + var innerExporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + _exporter = new CircuitBreakerTelemetryExporter(host, innerExporter); + } + + // Initialize aggregator + _aggregator = aggregator ?? new MetricsAggregator(_exporter, config, workspaceId, userAgent); + + // Initialize listener + _listener = listener ?? new DatabricksActivityListener(_aggregator, config); + + // Start the listener to begin collecting activities + _listener.Start(); + + // Start the background flush task + _backgroundFlushTask = StartBackgroundFlushTask(); + + Debug.WriteLine($"[TRACE] TelemetryClient: Initialized for host '{host}'"); + } + + /// + /// Exports telemetry frontend logs to the backend service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// A task representing the asynchronous export operation. + /// + /// This method delegates to the underlying exporter (CircuitBreakerTelemetryExporter). + /// It never throws exceptions (except for cancellation). + /// + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + return _exporter.ExportAsync(logs, ct); + } + + /// + /// Closes the telemetry client and releases all resources. + /// + /// A task representing the asynchronous close operation. + /// + /// This method implements graceful shutdown per Section 9.3 of the design doc: + /// 1. Cancels the background flush task + /// 2. Flushes all pending metrics synchronously + /// 3. Waits for background task to complete (with timeout) + /// 4. Disposes all resources + /// + /// This method never throws exceptions. All errors are swallowed and logged at TRACE level. + /// This method is idempotent - calling it multiple times has no additional effect. 
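The internal constructor's optional parameters exist so tests can inject components without touching the listener/aggregator wiring. A minimal sketch that swaps in the in-memory exporter from above; it assumes internal test access and uses a placeholder host, so no real telemetry traffic is produced.

```csharp
using System.Net.Http;
using System.Threading.Tasks;
using AdbcDrivers.Databricks.Telemetry;

public static class ClientInjectionSketch
{
    public static async Task RunAsync()
    {
        using var httpClient = new HttpClient();
        var fakeExporter = new InMemoryTelemetryExporter();   // illustrative test double
        var config = new TelemetryConfiguration();

        // Passing null for aggregator and listener lets the client build its defaults
        // around the injected exporter.
        var client = new TelemetryClient(
            host: "https://example.cloud.databricks.com",
            httpClient: httpClient,
            config: config,
            workspaceId: 0,
            userAgent: "test",
            exporter: fakeExporter,
            aggregator: null,
            listener: null);

        // ... drive some activities through the driver under test ...

        await client.CloseAsync();   // flushes pending events into fakeExporter and shuts down
    }
}
```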
+ /// + public async Task CloseAsync() + { + lock (_closeLock) + { + if (_closed) + { + return; + } + _closed = true; + } + + Debug.WriteLine($"[TRACE] TelemetryClient: Closing client for host '{_host}'"); + + try + { + // Step 1: Cancel the background flush task + _cts.Cancel(); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Error cancelling background task: {ex.Message}"); + } + + try + { + // Step 2: Stop the listener and flush pending metrics via the listener + // This also calls _aggregator.FlushAsync internally + await _listener.StopAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Error stopping listener: {ex.Message}"); + } + + try + { + // Step 3: Wait for background task to complete (with timeout) +#if NET6_0_OR_GREATER + await _backgroundFlushTask.WaitAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false); +#else + // For older frameworks, use Task.WhenAny with a delay + var completedTask = await Task.WhenAny( + _backgroundFlushTask, + Task.Delay(TimeSpan.FromSeconds(5))).ConfigureAwait(false); + + // If the flush task didn't complete in time, just continue + if (completedTask != _backgroundFlushTask) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Background task did not complete within timeout"); + } +#endif + } + catch (OperationCanceledException) + { + // Expected when the task is cancelled + } + catch (TimeoutException) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Background task did not complete within timeout"); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Error waiting for background task: {ex.Message}"); + } + + try + { + // Step 4: Dispose resources + _listener.Dispose(); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Error disposing listener: {ex.Message}"); + } + + try + { + _cts.Dispose(); + } + catch (Exception ex) + { + Debug.WriteLine($"[TRACE] TelemetryClient: Error disposing cancellation token source: {ex.Message}"); + } + + Debug.WriteLine($"[TRACE] TelemetryClient: Closed client for host '{_host}'"); + } + + /// + /// Gets the metrics aggregator for this client. + /// + internal MetricsAggregator Aggregator => _aggregator; + + /// + /// Gets the activity listener for this client. + /// + internal DatabricksActivityListener Listener => _listener; + + /// + /// Gets the telemetry exporter for this client. + /// + internal ITelemetryExporter Exporter => _exporter; + + /// + /// Gets whether the client has been closed. + /// + internal bool IsClosed => _closed; + + /// + /// Starts the background flush task that periodically flushes pending metrics. 
+ /// + private Task StartBackgroundFlushTask() + { + return Task.Run(async () => + { + Debug.WriteLine($"[TRACE] TelemetryClient: Background flush task started for host '{_host}'"); + + try + { + while (!_cts.Token.IsCancellationRequested) + { + try + { + // Wait for the flush interval + await Task.Delay(_config.FlushIntervalMs, _cts.Token).ConfigureAwait(false); + + // Flush pending metrics + await _aggregator.FlushAsync(_cts.Token).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Expected when cancelled - exit the loop + break; + } + catch (Exception ex) + { + // Swallow all other exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] TelemetryClient: Error in background flush: {ex.Message}"); + } + } + } + catch (Exception ex) + { + // Outer exception handler for any unexpected errors + Debug.WriteLine($"[TRACE] TelemetryClient: Background flush task error: {ex.Message}"); + } + + Debug.WriteLine($"[TRACE] TelemetryClient: Background flush task stopped for host '{_host}'"); + }, _cts.Token); + } + } +} diff --git a/csharp/src/Telemetry/TelemetryClientAdapter.cs b/csharp/src/Telemetry/TelemetryClientAdapter.cs new file mode 100644 index 00000000..a50b037b --- /dev/null +++ b/csharp/src/Telemetry/TelemetryClientAdapter.cs @@ -0,0 +1,103 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Adapter that wraps an ITelemetryExporter and implements ITelemetryClient. + /// + /// + /// This adapter bridges the gap between the ITelemetryExporter interface + /// (used for exporting telemetry events) and the ITelemetryClient interface + /// (used by TelemetryClientManager for per-host client management). + /// + /// The adapter: + /// - Delegates ExportAsync calls to the underlying exporter + /// - Provides a no-op CloseAsync since the exporter doesn't have close semantics + /// + internal sealed class TelemetryClientAdapter : ITelemetryClient + { + private readonly string _host; + private readonly ITelemetryExporter _exporter; + + /// + /// Gets the host URL for this telemetry client. + /// + public string Host => _host; + + /// + /// Gets the underlying telemetry exporter. + /// + internal ITelemetryExporter Exporter => _exporter; + + /// + /// Creates a new TelemetryClientAdapter. + /// + /// The Databricks host URL. + /// The telemetry exporter to wrap. + /// Thrown when host is null or whitespace. + /// Thrown when exporter is null. + public TelemetryClientAdapter(string host, ITelemetryExporter exporter) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + _host = host; + _exporter = exporter ?? 
throw new ArgumentNullException(nameof(exporter)); + } + + /// + /// Exports telemetry frontend logs to the backend service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. + /// A task representing the asynchronous export operation. + /// + /// This method delegates to the underlying exporter. It never throws + /// exceptions (except for cancellation) as per the telemetry requirement. + /// + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + return _exporter.ExportAsync(logs, ct); + } + + /// + /// Closes the telemetry client and releases all resources. + /// + /// A task representing the asynchronous close operation. + /// + /// Currently a no-op since ITelemetryExporter doesn't have close semantics. + /// If the underlying exporter implements IDisposable in the future, + /// this method should be updated to dispose it. + /// + public Task CloseAsync() + { + Debug.WriteLine($"[TRACE] TelemetryClientAdapter: Closing client for host '{_host}'"); + // No-op for now since ITelemetryExporter doesn't have close/dispose semantics + // The exporter is stateless and doesn't hold any resources that need cleanup + return Task.CompletedTask; + } + } +} diff --git a/csharp/src/Telemetry/TelemetryClientHolder.cs b/csharp/src/Telemetry/TelemetryClientHolder.cs new file mode 100644 index 00000000..3767c452 --- /dev/null +++ b/csharp/src/Telemetry/TelemetryClientHolder.cs @@ -0,0 +1,87 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Threading; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Holds a telemetry client and its reference count. + /// + /// + /// This class is used by to manage + /// shared telemetry clients per host with reference counting. When the + /// reference count reaches zero, the client should be closed and removed. + /// + /// Thread-safety is ensured using Interlocked operations for the reference count. + /// + /// JDBC Reference: TelemetryClientHolder.java + /// + internal sealed class TelemetryClientHolder + { + private readonly ITelemetryClient _client; + private int _refCount; + + /// + /// Gets the telemetry client. + /// + public ITelemetryClient Client => _client; + + /// + /// Gets the current reference count. + /// + public int RefCount => Volatile.Read(ref _refCount); + + /// + /// Creates a new TelemetryClientHolder with the specified client. + /// + /// The telemetry client to hold. + /// Thrown when client is null. + /// + /// The initial reference count is 0. Call + /// after creation to register the first reference. + /// + public TelemetryClientHolder(ITelemetryClient client) + { + _client = client ?? throw new ArgumentNullException(nameof(client)); + _refCount = 0; + } + + /// + /// Increments the reference count. + /// + /// The new reference count. + public int IncrementRefCount() + { + return Interlocked.Increment(ref _refCount); + } + + /// + /// Decrements the reference count. 
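The holder's contract (start at zero, increment on acquire, and have the caller close the client when a decrement reaches zero) can be pinned down with a short test. A sketch assuming xUnit, internal test access, and the illustrative adapter and exporter from above:

```csharp
using AdbcDrivers.Databricks.Telemetry;
using Xunit;

public class TelemetryClientHolderSketchTests
{
    [Fact]
    public void RefCountTracksAcquireAndRelease()
    {
        var holder = new TelemetryClientHolder(new TelemetryClientAdapter(
            "https://example.cloud.databricks.com",
            new InMemoryTelemetryExporter()));   // illustrative exporter from the sketch above

        Assert.Equal(0, holder.RefCount);        // starts unreferenced

        Assert.Equal(1, holder.IncrementRefCount());
        Assert.Equal(2, holder.IncrementRefCount());

        Assert.Equal(1, holder.DecrementRefCount());
        Assert.Equal(0, holder.DecrementRefCount());
        // At zero, the caller (TelemetryClientManager) is responsible for CloseAsync and removal.
    }
}
```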
+ /// + /// The new reference count. + /// + /// When the reference count reaches zero, the caller is responsible + /// for calling on the client + /// and removing this holder from the manager. + /// + public int DecrementRefCount() + { + return Interlocked.Decrement(ref _refCount); + } + } +} diff --git a/csharp/src/Telemetry/TelemetryClientManager.cs b/csharp/src/Telemetry/TelemetryClientManager.cs new file mode 100644 index 00000000..d8007363 --- /dev/null +++ b/csharp/src/Telemetry/TelemetryClientManager.cs @@ -0,0 +1,248 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Net.Http; +using System.Threading.Tasks; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Singleton factory that manages one telemetry client per host. + /// Prevents rate limiting by sharing clients across connections. + /// + /// + /// This class implements the per-host client management pattern from the JDBC driver: + /// - One telemetry client per host to prevent rate limiting from concurrent connections + /// - Large customers (e.g., Celonis) open many parallel connections to the same host + /// - Shared client batches events from all connections, avoiding multiple concurrent flushes + /// - Reference counting tracks active connections, closes client when last connection closes + /// - Thread-safe using ConcurrentDictionary and atomic reference counting + /// + /// JDBC Reference: TelemetryClientFactory.java + /// + internal sealed class TelemetryClientManager + { + private static readonly TelemetryClientManager s_instance = new TelemetryClientManager(); + + private readonly ConcurrentDictionary _clients; + private readonly Func? _clientFactory; + + /// + /// Gets the singleton instance of the TelemetryClientManager. + /// + public static TelemetryClientManager GetInstance() => s_instance; + + /// + /// Creates a new TelemetryClientManager. + /// + internal TelemetryClientManager() + : this(null) + { + } + + /// + /// Creates a new TelemetryClientManager with a custom client factory. + /// + /// + /// Factory function to create telemetry clients. + /// If null, uses the default factory that creates CircuitBreakerTelemetryExporter + /// wrapped around DatabricksTelemetryExporter. + /// + /// + /// This constructor is primarily for testing to allow injecting mock clients. + /// + internal TelemetryClientManager(Func? clientFactory) + { + _clients = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _clientFactory = clientFactory; + } + + /// + /// Gets or creates a telemetry client for the host. + /// Increments reference count. + /// + /// The Databricks host URL. + /// The HTTP client to use for sending requests. + /// The telemetry configuration. + /// The telemetry client for the host. + /// Thrown when host is null or whitespace. + /// Thrown when httpClient or config is null. 
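From a connection's perspective the manager is a pair of bracketing calls: acquire a shared client on open, release it on close, so parallel connections to one host share a single client. A sketch of that call pattern; the connection wrapper type is illustrative, not driver code.

```csharp
using System.Net.Http;
using System.Threading.Tasks;
using AdbcDrivers.Databricks.Telemetry;

// Illustrative wrapper standing in for the driver connection's telemetry hookup.
internal sealed class ConnectionTelemetrySketch
{
    private readonly string _host;
    private readonly ITelemetryClient _telemetryClient;

    public ConnectionTelemetrySketch(string host, HttpClient httpClient, TelemetryConfiguration config)
    {
        _host = host;
        // Shared per host: a second connection to the same host reuses the same client.
        _telemetryClient = TelemetryClientManager.GetInstance()
            .GetOrCreateClient(host, httpClient, config);
    }

    public async Task CloseAsync()
    {
        // Decrements the ref count; the last connection's release closes the shared client.
        await TelemetryClientManager.GetInstance().ReleaseClientAsync(_host);
    }
}
```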
+ public ITelemetryClient GetOrCreateClient(string host, HttpClient httpClient, TelemetryConfiguration config) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + if (config == null) + { + throw new ArgumentNullException(nameof(config)); + } + + // Use GetOrAdd with a value factory to ensure thread-safety + var holder = _clients.GetOrAdd(host, h => + { + var client = CreateClient(h, httpClient, config); + return new TelemetryClientHolder(client); + }); + + // Increment the reference count + var newRefCount = holder.IncrementRefCount(); + Debug.WriteLine($"[TRACE] TelemetryClientManager: GetOrCreateClient for host '{host}', RefCount={newRefCount}"); + + return holder.Client; + } + + /// + /// Decrements reference count for the host. + /// Closes and removes client when ref count reaches zero. + /// + /// The host to release the client for. + /// A task representing the asynchronous release operation. + /// + /// This method is thread-safe. If the reference count reaches zero, + /// the client is closed and removed from the cache. If multiple threads + /// try to release the same client simultaneously, only one will successfully + /// close and remove it. + /// + public async Task ReleaseClientAsync(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return; + } + + if (_clients.TryGetValue(host, out var holder)) + { + var newRefCount = holder.DecrementRefCount(); + Debug.WriteLine($"[TRACE] TelemetryClientManager: ReleaseClientAsync for host '{host}', RefCount={newRefCount}"); + + if (newRefCount <= 0) + { + // Try to remove the holder. Use TryRemove to avoid race conditions + // where a new connection added a reference. + if (holder.RefCount <= 0) + { + // Check RefCount again because another thread might have + // incremented it between our check and the removal attempt. +#if NET5_0_OR_GREATER + if (_clients.TryRemove(new System.Collections.Generic.KeyValuePair(host, holder))) +#else + // For netstandard2.0, we need to be more careful about the removal + if (_clients.TryGetValue(host, out var currentHolder) && currentHolder == holder && currentHolder.RefCount <= 0) +#endif + { +#if !NET5_0_OR_GREATER + // For netstandard2.0, attempt to remove + ((System.Collections.Generic.IDictionary)_clients).Remove( + new System.Collections.Generic.KeyValuePair(host, holder)); +#endif + Debug.WriteLine($"[TRACE] TelemetryClientManager: Closing client for host '{host}'"); + + try + { + await holder.Client.CloseAsync().ConfigureAwait(false); + Debug.WriteLine($"[TRACE] TelemetryClientManager: Closed and removed client for host '{host}'"); + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] TelemetryClientManager: Error closing client for host '{host}': {ex.Message}"); + } + } + } + } + } + } + + /// + /// Gets the number of hosts with active clients. + /// + internal int ClientCount => _clients.Count; + + /// + /// Checks if a client exists for the specified host. + /// + /// The host to check. + /// True if a client exists, false otherwise. + internal bool HasClient(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _clients.ContainsKey(host); + } + + /// + /// Gets the holder for the specified host, if it exists. + /// Does not create a new client or modify reference count. + /// + /// The host to get the holder for. 
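Because the internal constructor accepts a client factory, the per-host sharing and close-on-last-release behavior can be exercised without real telemetry components. A test sketch assuming xUnit and the illustrative adapter and exporter from above:

```csharp
using System.Net.Http;
using System.Threading.Tasks;
using AdbcDrivers.Databricks.Telemetry;
using Xunit;

public class TelemetryClientManagerSketchTests
{
    [Fact]
    public async Task SharesOneClientPerHostAndClosesOnLastRelease()
    {
        // Inject a factory so no HTTP-backed TelemetryClient is constructed.
        var manager = new TelemetryClientManager(
            (host, http, cfg) => new TelemetryClientAdapter(host, new InMemoryTelemetryExporter()));

        using var httpClient = new HttpClient();
        var config = new TelemetryConfiguration();
        const string hostUrl = "https://example.cloud.databricks.com";

        var first = manager.GetOrCreateClient(hostUrl, httpClient, config);
        var second = manager.GetOrCreateClient(hostUrl, httpClient, config);

        Assert.Same(first, second);               // one client per host
        Assert.Equal(1, manager.ClientCount);

        await manager.ReleaseClientAsync(hostUrl);   // ref count 2 -> 1, client stays cached
        Assert.True(manager.HasClient(hostUrl));

        await manager.ReleaseClientAsync(hostUrl);   // ref count 1 -> 0, client closed and removed
        Assert.False(manager.HasClient(hostUrl));
    }
}
```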
+ /// The holder if found, null otherwise. + /// True if the holder was found, false otherwise. + internal bool TryGetHolder(string host, out TelemetryClientHolder? holder) + { + holder = null; + + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + if (_clients.TryGetValue(host, out var foundHolder)) + { + holder = foundHolder; + return true; + } + + return false; + } + + /// + /// Clears all clients. + /// This is primarily for testing purposes. + /// + internal void Clear() + { + _clients.Clear(); + } + + /// + /// Creates a new telemetry client for the specified host. + /// + private ITelemetryClient CreateClient(string host, HttpClient httpClient, TelemetryConfiguration config) + { + if (_clientFactory != null) + { + return _clientFactory(host, httpClient, config); + } + + // Default factory: Create full TelemetryClient that coordinates + // listener, aggregator, and exporter with background flush task + return new TelemetryClient(host, httpClient, config); + } + } +} diff --git a/csharp/src/Telemetry/TelemetryConfiguration.cs b/csharp/src/Telemetry/TelemetryConfiguration.cs index d2b1732c..a66dbb5c 100644 --- a/csharp/src/Telemetry/TelemetryConfiguration.cs +++ b/csharp/src/Telemetry/TelemetryConfiguration.cs @@ -84,6 +84,13 @@ public sealed class TelemetryConfiguration /// public const string FeatureFlagName = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"; + /// + /// Feature flag endpoint format (relative to host). + /// {0} = driver version without OSS suffix. + /// NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side. + /// + public const string FeatureFlagEndpointFormat = "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; + /// /// Gets or sets whether telemetry is enabled. /// Default is true. diff --git a/csharp/test/E2E/FeatureFlagCacheE2ETest.cs b/csharp/test/E2E/FeatureFlagCacheE2ETest.cs new file mode 100644 index 00000000..1c674efe --- /dev/null +++ b/csharp/test/E2E/FeatureFlagCacheE2ETest.cs @@ -0,0 +1,237 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tests; +using Xunit; +using Xunit.Abstractions; + +namespace AdbcDrivers.Databricks.Tests +{ + /// + /// End-to-end tests for the FeatureFlagCache functionality using a real Databricks instance. + /// Tests that feature flags are properly fetched and cached from the Databricks connector service. + /// + public class FeatureFlagCacheE2ETest : TestBase + { + public FeatureFlagCacheE2ETest(ITestOutputHelper? 
outputHelper) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) + { + // Skip the test if the DATABRICKS_TEST_CONFIG_FILE environment variable is not set + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Tests that creating a connection successfully initializes the feature flag cache + /// and verifies that flags are actually fetched from the server. + /// + [SkippableFact] + public async Task TestFeatureFlagCacheInitialization() + { + // Arrange + var cache = FeatureFlagCache.GetInstance(); + var hostName = GetNormalizedHostName(); + Skip.If(string.IsNullOrEmpty(hostName), "Cannot determine host name from test configuration"); + + // Act - Create a connection which initializes the feature flag cache + using var connection = NewConnection(TestConfiguration); + + // Assert - The connection should be created successfully + Assert.NotNull(connection); + + // Verify the feature flag context exists for this host + Assert.True(cache.TryGetContext(hostName!, out var context), "Feature flag context should exist after connection creation"); + Assert.NotNull(context); + + // Verify that some flags were fetched from the server + // The server should return at least some feature flags + var flags = context.GetAllFlags(); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Fetched {flags.Count} feature flags from server"); + foreach (var flag in flags) + { + OutputHelper?.WriteLine($" - {flag.Key}: {flag.Value}"); + } + + // Log the TTL from the server + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] TTL from server: {context.Ttl.TotalSeconds} seconds"); + + // Note: We don't assert flags.Count > 0 because the server may return empty flags + // in some environments, but we verify the infrastructure works + + // Execute a simple query to verify the connection works + using var statement = connection.CreateStatement(); + statement.SqlQuery = "SELECT 1 as test_value"; + var result = await statement.ExecuteQueryAsync(); + + Assert.NotNull(result.Stream); + var batch = await result.Stream.ReadNextRecordBatchAsync(); + Assert.NotNull(batch); + Assert.Equal(1, batch.Length); + + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Connection with feature flag cache initialized successfully"); + } + + /// + /// Tests that multiple connections to the same host share the same feature flag context. + /// This verifies the per-host caching behavior. 
+ /// + [SkippableFact] + public async Task TestFeatureFlagCacheSharedAcrossConnections() + { + // Arrange - Get the singleton cache instance + var cache = FeatureFlagCache.GetInstance(); + + // Act - Create two connections to the same host + using var connection1 = NewConnection(TestConfiguration); + using var connection2 = NewConnection(TestConfiguration); + + // Assert - Both connections should work properly + Assert.NotNull(connection1); + Assert.NotNull(connection2); + + // Verify both connections can execute queries + using var statement1 = connection1.CreateStatement(); + statement1.SqlQuery = "SELECT 1 as conn1_test"; + var result1 = await statement1.ExecuteQueryAsync(); + Assert.NotNull(result1.Stream); + + using var statement2 = connection2.CreateStatement(); + statement2.SqlQuery = "SELECT 2 as conn2_test"; + var result2 = await statement2.ExecuteQueryAsync(); + Assert.NotNull(result2.Stream); + + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Multiple connections sharing feature flag cache work correctly"); + } + + /// + /// Tests that the feature flag cache persists after connections close (TTL-based expiration). + /// With the IMemoryCache implementation, contexts stay in cache until TTL expires, + /// not when connections close. + /// + [SkippableFact] + public async Task TestFeatureFlagCachePersistsAfterConnectionClose() + { + // Arrange + var cache = FeatureFlagCache.GetInstance(); + var hostName = GetNormalizedHostName(); + Skip.If(string.IsNullOrEmpty(hostName), "Cannot determine host name from test configuration"); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Initial cache count: {cache.CachedHostCount}"); + + // Act - Create and close a connection + using (var connection = NewConnection(TestConfiguration)) + { + // Connection is active, cache should have a context for this host + Assert.NotNull(connection); + + // Verify context exists during connection + Assert.True(cache.TryGetContext(hostName!, out var context), "Context should exist while connection is active"); + Assert.NotNull(context); + + // Verify flags were fetched + var flags = context.GetAllFlags(); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Flags fetched: {flags.Count}"); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] TTL: {context.Ttl.TotalSeconds} seconds"); + + // Execute a query to ensure the connection is fully initialized + using var statement = connection.CreateStatement(); + statement.SqlQuery = "SELECT 1"; + var result = await statement.ExecuteQueryAsync(); + Assert.NotNull(result.Stream); + } + // Connection is disposed here + + // With TTL-based caching, the context should still exist in the cache + // (it only gets removed when TTL expires or cache is explicitly cleared) + Assert.True(cache.TryGetContext(hostName!, out var contextAfterDispose), + "Context should still exist in cache after connection close (TTL-based expiration)"); + Assert.NotNull(contextAfterDispose); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Context still exists after connection close with TTL: {contextAfterDispose.Ttl.TotalSeconds}s"); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Final cache count: {cache.CachedHostCount}"); + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Feature flag cache TTL-based persistence test completed"); + } + + /// + /// Tests that connections work correctly with feature flags enabled. + /// This is a basic sanity check that the feature flag infrastructure doesn't + /// interfere with normal connection operations. 
+ /// + [SkippableFact] + public async Task TestConnectionWithFeatureFlagsExecutesQueries() + { + // Arrange + using var connection = NewConnection(TestConfiguration); + + // Act - Execute multiple queries to ensure feature flags don't interfere + var queries = new[] + { + "SELECT 1 as value", + "SELECT 'hello' as greeting", + "SELECT CURRENT_DATE() as today" + }; + + foreach (var query in queries) + { + using var statement = connection.CreateStatement(); + statement.SqlQuery = query; + var result = await statement.ExecuteQueryAsync(); + + // Assert + Assert.NotNull(result.Stream); + var batch = await result.Stream.ReadNextRecordBatchAsync(); + Assert.NotNull(batch); + Assert.True(batch.Length > 0, $"Query '{query}' should return at least one row"); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Query executed successfully: {query}"); + } + } + + /// + /// Gets the normalized host name from test configuration. + /// Strips protocol prefix if present (e.g., "https://host" -> "host"). + /// + private string? GetNormalizedHostName() + { + var hostName = TestConfiguration.HostName ?? TestConfiguration.Uri; + if (string.IsNullOrEmpty(hostName)) + { + return null; + } + + // Try to parse as URI first + if (Uri.TryCreate(hostName, UriKind.Absolute, out Uri? parsedUri) && + (parsedUri.Scheme == Uri.UriSchemeHttp || parsedUri.Scheme == Uri.UriSchemeHttps)) + { + return parsedUri.Host; + } + + // Fallback: strip common protocol prefixes manually + if (hostName.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + return hostName.Substring(8); + } + if (hostName.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + return hostName.Substring(7); + } + + return hostName; + } + } +} diff --git a/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs b/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs new file mode 100644 index 00000000..e4bca1e1 --- /dev/null +++ b/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs @@ -0,0 +1,491 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Apache.Arrow.Adbc.Tests; +using Xunit; +using Xunit.Abstractions; + +namespace AdbcDrivers.Databricks.Tests.E2E.Telemetry +{ + /// + /// End-to-end tests for client telemetry that sends telemetry to Databricks endpoint. + /// These tests verify that the DatabricksTelemetryExporter can successfully send + /// telemetry events to real Databricks telemetry endpoints. + /// + public class ClientTelemetryE2ETests : TestBase + { + public ClientTelemetryE2ETests(ITestOutputHelper? outputHelper) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) + { + } + + /// + /// Tests that telemetry can be sent to the authenticated endpoint (/telemetry-ext). + /// This endpoint requires a valid authentication token. 
+ /// + [SkippableFact] + public async Task CanSendTelemetryToAuthenticatedEndpoint() + { + // Skip if no token is available + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing authenticated telemetry endpoint at {host}/telemetry-ext"); + + // Create HttpClient with authentication + using var httpClient = CreateAuthenticatedHttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Verify endpoint URL + var endpointUrl = exporter.GetEndpointUrl(); + Assert.Equal($"{host}/telemetry-ext", endpointUrl); + OutputHelper?.WriteLine($"Endpoint URL: {endpointUrl}"); + + // Create a test telemetry log + var logs = CreateTestTelemetryLogs(1); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + // ExportAsync should return true indicating HTTP 200 response + Assert.True(success, "ExportAsync should return true indicating successful HTTP 200 response"); + OutputHelper?.WriteLine("Successfully sent telemetry to authenticated endpoint"); + } + + /// + /// Tests that telemetry can be sent to the unauthenticated endpoint (/telemetry-unauth). + /// This endpoint does not require authentication. + /// + [SkippableFact] + public async Task CanSendTelemetryToUnauthenticatedEndpoint() + { + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing unauthenticated telemetry endpoint at {host}/telemetry-unauth"); + + // Create HttpClient without authentication + using var httpClient = new HttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: false, config); + + // Verify endpoint URL + var endpointUrl = exporter.GetEndpointUrl(); + Assert.Equal($"{host}/telemetry-unauth", endpointUrl); + OutputHelper?.WriteLine($"Endpoint URL: {endpointUrl}"); + + // Create a test telemetry log + var logs = CreateTestTelemetryLogs(1); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + // ExportAsync should return true indicating HTTP 200 response + Assert.True(success, "ExportAsync should return true indicating successful HTTP 200 response"); + OutputHelper?.WriteLine("Successfully sent telemetry to unauthenticated endpoint"); + } + + /// + /// Tests that multiple telemetry logs can be batched and sent together. 
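These tests call a `CreateAuthenticatedHttpClient()` helper that is defined later in the file and not visible in this hunk. One plausible shape, assuming it only attaches the configured token as a Bearer header; `TestConfiguration.Token` and `TestConfiguration.AccessToken` are the same properties the Skip checks read.

```csharp
using System.Net.Http;
using System.Net.Http.Headers;

// Sketch of a test-class member; the real helper lives further down in this file.
private HttpClient CreateAuthenticatedHttpClient()
{
    var httpClient = new HttpClient();

    // Prefer the PAT if configured, otherwise fall back to an OAuth access token.
    string? token = !string.IsNullOrEmpty(TestConfiguration.Token)
        ? TestConfiguration.Token
        : TestConfiguration.AccessToken;

    if (!string.IsNullOrEmpty(token))
    {
        httpClient.DefaultRequestHeaders.Authorization =
            new AuthenticationHeaderValue("Bearer", token);
    }

    return httpClient;
}
```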
+        /// </summary>
+        [SkippableFact]
+        public async Task CanSendBatchedTelemetryLogs()
+        {
+            Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken),
+                "Token is required for authenticated telemetry endpoint test");
+
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            OutputHelper?.WriteLine($"Testing batched telemetry to {host}");
+
+            using var httpClient = CreateAuthenticatedHttpClient();
+
+            var config = new TelemetryConfiguration
+            {
+                MaxRetries = 2,
+                RetryDelayMs = 100
+            };
+
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+
+            // Create multiple telemetry logs
+            var logs = CreateTestTelemetryLogs(5);
+            OutputHelper?.WriteLine($"Created {logs.Count} telemetry logs for batch send");
+
+            // Send telemetry - should succeed and return true
+            var success = await exporter.ExportAsync(logs);
+
+            Assert.True(success, "ExportAsync should return true for batched telemetry");
+            OutputHelper?.WriteLine("Successfully sent batched telemetry logs");
+        }
+
+        /// <summary>
+        /// Tests that the telemetry request is properly formatted.
+        /// </summary>
+        [SkippableFact]
+        public void TelemetryRequestIsProperlyFormatted()
+        {
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            using var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration();
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+
+            // Create test logs
+            var logs = CreateTestTelemetryLogs(2);
+
+            // Create the request
+            var request = exporter.CreateTelemetryRequest(logs);
+
+            // Verify request structure
+            Assert.True(request.UploadTime > 0, "UploadTime should be a positive timestamp");
+            Assert.Equal(2, request.ProtoLogs.Count);
+
+            // Verify each log is serialized as JSON
+            foreach (var protoLog in request.ProtoLogs)
+            {
+                Assert.NotEmpty(protoLog);
+                Assert.Contains("workspace_id", protoLog);
+                Assert.Contains("frontend_log_event_id", protoLog);
+                OutputHelper?.WriteLine($"Serialized log: {protoLog}");
+            }
+
+            // Verify the full request serialization
+            var json = exporter.SerializeRequest(request);
+            Assert.Contains("uploadTime", json);
+            Assert.Contains("protoLogs", json);
+            OutputHelper?.WriteLine($"Full request JSON: {json}");
+        }
+
+        /// <summary>
+        /// Tests telemetry with a complete TelemetryFrontendLog including all nested objects.
+        /// </summary>
+        [SkippableFact]
+        public async Task CanSendCompleteTelemetryEvent()
+        {
+            Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken),
+                "Token is required for authenticated telemetry endpoint test");
+
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            OutputHelper?.WriteLine($"Testing complete telemetry event to {host}");
+
+            using var httpClient = CreateAuthenticatedHttpClient();
+
+            var config = new TelemetryConfiguration
+            {
+                MaxRetries = 2,
+                RetryDelayMs = 100
+            };
+
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+
+            // Create a complete telemetry log with all fields populated
+            var log = new TelemetryFrontendLog
+            {
+                WorkspaceId = 12345678901234,
+                FrontendLogEventId = Guid.NewGuid().ToString(),
+                Context = new FrontendLogContext
+                {
+                    TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
+                    ClientContext = new TelemetryClientContext
+                    {
+                        UserAgent = "AdbcDatabricksDriver/1.0.0-test (.NET; E2E Test)"
+                    }
+                },
+                Entry = new FrontendLogEntry
+                {
+                    SqlDriverLog = new TelemetryEvent
+                    {
+                        SessionId = Guid.NewGuid().ToString(),
+                        SqlStatementId = Guid.NewGuid().ToString(),
+                        OperationLatencyMs = 150,
+                        SystemConfiguration = new DriverSystemConfiguration
+                        {
+                            DriverName = "Databricks ADBC Driver",
+                            DriverVersion = "1.0.0-test",
+                            OsName = Environment.OSVersion.Platform.ToString(),
+                            OsVersion = Environment.OSVersion.Version.ToString(),
+                            RuntimeName = ".NET",
+                            RuntimeVersion = Environment.Version.ToString(),
+                            Locale = System.Globalization.CultureInfo.CurrentCulture.Name,
+                            Timezone = TimeZoneInfo.Local.Id
+                        }
+                    }
+                }
+            };
+
+            var logs = new List<TelemetryFrontendLog> { log };
+
+            OutputHelper?.WriteLine($"Sending complete telemetry event:");
+            OutputHelper?.WriteLine($" WorkspaceId: {log.WorkspaceId}");
+            OutputHelper?.WriteLine($" EventId: {log.FrontendLogEventId}");
+            OutputHelper?.WriteLine($" SessionId: {log.Entry?.SqlDriverLog?.SessionId}");
+
+            // Send telemetry - should succeed and return true
+            var success = await exporter.ExportAsync(logs);
+
+            Assert.True(success, "ExportAsync should return true for complete telemetry event");
+            OutputHelper?.WriteLine("Successfully sent complete telemetry event");
+        }
+
+        /// <summary>
+        /// Tests that empty log lists are handled gracefully without making HTTP requests.
+        /// </summary>
+        [SkippableFact]
+        public async Task EmptyLogListDoesNotSendRequest()
+        {
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            using var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration();
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+
+            // Send empty list - should return immediately with true (nothing to export)
+            var success = await exporter.ExportAsync(new List<TelemetryFrontendLog>());
+
+            Assert.True(success, "ExportAsync should return true for empty list (nothing to export)");
+            OutputHelper?.WriteLine("Empty log list handled gracefully");
+
+            // Send null list - should also return immediately with true (nothing to export)
+            success = await exporter.ExportAsync(null!);
+
+            Assert.True(success, "ExportAsync should return true for null list (nothing to export)");
+            OutputHelper?.WriteLine("Null log list handled gracefully");
+        }
+
+        /// <summary>
+        /// Tests that the exporter properly retries on transient failures.
+        /// This test verifies the retry configuration is respected.
+        /// </summary>
+        [SkippableFact]
+        public async Task TelemetryExporterRespectsRetryConfiguration()
+        {
+            Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken),
+                "Token is required for authenticated telemetry endpoint test");
+
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            OutputHelper?.WriteLine("Testing retry configuration");
+
+            using var httpClient = CreateAuthenticatedHttpClient();
+
+            // Configure with specific retry settings
+            var config = new TelemetryConfiguration
+            {
+                MaxRetries = 3,
+                RetryDelayMs = 50
+            };
+
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+
+            var logs = CreateTestTelemetryLogs(1);
+
+            // This should succeed and return true
+            var success = await exporter.ExportAsync(logs);
+
+            // Exporter should return true on success
+            Assert.True(success, "ExportAsync should return true with retry configuration");
+            OutputHelper?.WriteLine("Retry configuration test completed");
+        }
+
+        /// <summary>
+        /// Tests that the authenticated telemetry endpoint returns HTTP 200.
+        /// This test directly calls the endpoint to verify the response status code.
+        /// </summary>
+        [SkippableFact]
+        public async Task AuthenticatedEndpointReturnsHttp200()
+        {
+            Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken),
+                "Token is required for authenticated telemetry endpoint test");
+
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            var endpointUrl = $"{host}/telemetry-ext";
+            OutputHelper?.WriteLine($"Testing HTTP response from {endpointUrl}");
+
+            using var httpClient = CreateAuthenticatedHttpClient();
+
+            // Create a minimal telemetry request
+            var logs = CreateTestTelemetryLogs(1);
+            var config = new TelemetryConfiguration();
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config);
+            var request = exporter.CreateTelemetryRequest(logs);
+            var json = exporter.SerializeRequest(request);
+
+            // Send the request directly to verify HTTP status
+            using var content = new System.Net.Http.StringContent(json, System.Text.Encoding.UTF8, "application/json");
+            using var response = await httpClient.PostAsync(endpointUrl, content);
+
+            OutputHelper?.WriteLine($"HTTP Status Code: {(int)response.StatusCode} ({response.StatusCode})");
+            OutputHelper?.WriteLine($"Response Headers: {response.Headers}");
+
+            var responseBody = await response.Content.ReadAsStringAsync();
+            OutputHelper?.WriteLine($"Response Body: {responseBody}");
+
+            // Assert we get HTTP 200
+            Assert.Equal(System.Net.HttpStatusCode.OK, response.StatusCode);
+            OutputHelper?.WriteLine("Verified: Authenticated endpoint returns HTTP 200");
+        }
+
+        /// <summary>
+        /// Tests that the unauthenticated telemetry endpoint returns HTTP 200.
+        /// This test directly calls the endpoint to verify the response status code.
+        /// </summary>
+        [SkippableFact]
+        public async Task UnauthenticatedEndpointReturnsHttp200()
+        {
+            var host = GetDatabricksHost();
+            Skip.If(string.IsNullOrEmpty(host), "Databricks host is required");
+
+            var endpointUrl = $"{host}/telemetry-unauth";
+            OutputHelper?.WriteLine($"Testing HTTP response from {endpointUrl}");
+
+            using var httpClient = new HttpClient();
+
+            // Create a minimal telemetry request
+            var logs = CreateTestTelemetryLogs(1);
+            var config = new TelemetryConfiguration();
+            var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: false, config);
+            var request = exporter.CreateTelemetryRequest(logs);
+            var json = exporter.SerializeRequest(request);
+
+            // Send the request directly to verify HTTP status
+            using var content = new System.Net.Http.StringContent(json, System.Text.Encoding.UTF8, "application/json");
+            using var response = await httpClient.PostAsync(endpointUrl, content);
+
+            OutputHelper?.WriteLine($"HTTP Status Code: {(int)response.StatusCode} ({response.StatusCode})");
+            OutputHelper?.WriteLine($"Response Headers: {response.Headers}");
+
+            var responseBody = await response.Content.ReadAsStringAsync();
+            OutputHelper?.WriteLine($"Response Body: {responseBody}");
+
+            // Assert we get HTTP 200
+            Assert.Equal(System.Net.HttpStatusCode.OK, response.StatusCode);
+            OutputHelper?.WriteLine("Verified: Unauthenticated endpoint returns HTTP 200");
+        }
+
+        #region Helper Methods
+
+        /// <summary>
+        /// Gets the Databricks host URL from the test configuration.
+        /// </summary>
+        private string GetDatabricksHost()
+        {
+            // Try Uri first, then fall back to HostName
+            if (!string.IsNullOrEmpty(TestConfiguration.Uri))
+            {
+                var uri = new Uri(TestConfiguration.Uri);
+                return $"{uri.Scheme}://{uri.Host}";
+            }
+
+            if (!string.IsNullOrEmpty(TestConfiguration.HostName))
+            {
+                return $"https://{TestConfiguration.HostName}";
+            }
+
+            return string.Empty;
+        }
+
+        /// <summary>
+        /// Creates an HttpClient with authentication headers.
+        /// </summary>
+        private HttpClient CreateAuthenticatedHttpClient()
+        {
+            var httpClient = new HttpClient();
+
+            // Use AccessToken if available, otherwise fall back to Token
+            var token = !string.IsNullOrEmpty(TestConfiguration.AccessToken)
+                ? TestConfiguration.AccessToken
+                : TestConfiguration.Token;
+
+            if (!string.IsNullOrEmpty(token))
+            {
+                httpClient.DefaultRequestHeaders.Authorization =
+                    new AuthenticationHeaderValue("Bearer", token);
+            }
+
+            return httpClient;
+        }
+
+        /// <summary>
+        /// Creates test telemetry logs for E2E testing.
+        /// </summary>
+        private IReadOnlyList<TelemetryFrontendLog> CreateTestTelemetryLogs(int count)
+        {
+            var logs = new List<TelemetryFrontendLog>(count);
+
+            for (int i = 0; i < count; i++)
+            {
+                logs.Add(new TelemetryFrontendLog
+                {
+                    WorkspaceId = 5870029948831567,
+                    FrontendLogEventId = Guid.NewGuid().ToString(),
+                    Context = new FrontendLogContext
+                    {
+                        TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
+                        ClientContext = new TelemetryClientContext
+                        {
+                            UserAgent = $"AdbcDatabricksDriver/1.0.0-test (.NET; E2E Test {i})"
+                        }
+                    },
+                    Entry = new FrontendLogEntry
+                    {
+                        SqlDriverLog = new TelemetryEvent
+                        {
+                            SessionId = Guid.NewGuid().ToString(),
+                            SqlStatementId = Guid.NewGuid().ToString(),
+                            OperationLatencyMs = 100 + (i * 10)
+                        }
+                    }
+                });
+            }
+
+            return logs;
+        }
+
+        #endregion
+    }
+}
diff --git a/csharp/test/E2E/Telemetry/FeatureFlagCacheE2ETests.cs b/csharp/test/E2E/Telemetry/FeatureFlagCacheE2ETests.cs
new file mode 100644
index 00000000..aadcfa47
--- /dev/null
+++ b/csharp/test/E2E/Telemetry/FeatureFlagCacheE2ETests.cs
@@ -0,0 +1,635 @@
+/*
+* Copyright (c) 2025 ADBC Drivers Contributors
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using AdbcDrivers.Databricks.Telemetry;
+using Apache.Arrow.Adbc.Tests;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace AdbcDrivers.Databricks.Tests.E2E.Telemetry
+{
+    /// <summary>
+    /// E2E tests for FeatureFlagCache.
+    /// Tests feature flag fetching from real Databricks endpoints and validates
+    /// caching and reference counting behavior.
+    /// </summary>
+    /// <remarks>
+    /// These tests require:
+    /// - DATABRICKS_TEST_CONFIG_FILE environment variable pointing to a valid test config
+    /// - The test config must include: hostName, token/access_token
+    /// </remarks>
+    public class FeatureFlagCacheE2ETests : TestBase
+    {
+        private readonly bool _canRunRealEndpointTests;
+        private readonly string? _host;
+        private readonly string? _token;
+
+        public FeatureFlagCacheE2ETests(ITestOutputHelper? outputHelper)
+            : base(outputHelper, new DatabricksTestEnvironment.Factory())
+        {
+            // Check if we can run tests against a real endpoint
+            _canRunRealEndpointTests = Utils.CanExecuteTestConfig(TestConfigVariable);
+
+            if (_canRunRealEndpointTests)
+            {
+                try
+                {
+                    // Try to get host from HostName first, then fallback to extracting from Uri
+                    _host = TestConfiguration.HostName;
+                    if (string.IsNullOrEmpty(_host) && !string.IsNullOrEmpty(TestConfiguration.Uri))
+                    {
+                        // Extract host from Uri (e.g., https://host.databricks.com/sql/1.0/...)
+                        var uri = new Uri(TestConfiguration.Uri);
+                        _host = uri.Host;
+                    }
+
+                    _token = TestConfiguration.Token ?? TestConfiguration.AccessToken;
+
+                    // Validate we have required configuration
+                    if (string.IsNullOrEmpty(_host) || string.IsNullOrEmpty(_token))
+                    {
+                        _canRunRealEndpointTests = false;
+                    }
+                }
+                catch
+                {
+                    _canRunRealEndpointTests = false;
+                }
+            }
+        }
+
+        /// <summary>
+        /// Creates an HttpClient configured with authentication for the Databricks endpoint.
+ /// + private HttpClient CreateAuthenticatedHttpClient() + { + var client = new HttpClient(); + client.DefaultRequestHeaders.Authorization = + new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _token); + return client; + } + + /// + /// Creates a feature flag fetcher function that simulates calling the feature flag endpoint. + /// The actual feature flag endpoint would be called by the driver during connection. + /// For E2E testing purposes, we create a fetcher that always returns a boolean value. + /// + private Func> CreateFeatureFlagFetcher(bool returnValue, int? delayMs = null) + { + return async ct => + { + if (delayMs.HasValue) + { + await Task.Delay(delayMs.Value, ct); + } + return returnValue; + }; + } + + /// + /// Creates a feature flag fetcher that makes a real HTTP call to validate connectivity. + /// Uses the telemetry endpoint to verify the host is reachable. + /// + private Func> CreateRealEndpointFetcher(HttpClient httpClient) + { + return async ct => + { + // Make a lightweight request to validate endpoint connectivity + // In production, this would be a feature flag API call + // For E2E testing, we verify the host is reachable + try + { + var host = _host!.TrimEnd('/'); + if (!host.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + host = "https://" + host; + } + + // Try to reach a Databricks API endpoint to validate connectivity + // Using a simple GET request that should work with bearer auth + var request = new HttpRequestMessage(HttpMethod.Get, $"{host}/api/2.0/clusters/list"); + var response = await httpClient.SendAsync(request, ct); + + // Any response (even 403) means we reached the endpoint + // The feature flag would be determined by the response in production + return response.IsSuccessStatusCode || + response.StatusCode == System.Net.HttpStatusCode.Forbidden || + response.StatusCode == System.Net.HttpStatusCode.Unauthorized; + } + catch (HttpRequestException) + { + // Network error - endpoint not reachable + return false; + } + }; + } + + #region FeatureFlagCache_FetchFromRealEndpoint_ReturnsBoolean + + /// + /// Tests that FeatureFlagCache can fetch feature flags from a real Databricks endpoint. + /// This validates the cache correctly handles real-world HTTP responses. + /// + [SkippableFact] + public async Task FeatureFlagCache_FetchFromRealEndpoint_ReturnsBoolean() + { + Skip.IfNot(_canRunRealEndpointTests, "Real endpoint testing requires DATABRICKS_TEST_CONFIG_FILE"); + + // Arrange + var cache = new FeatureFlagCache(); + var host = _host!; + + using var httpClient = CreateAuthenticatedHttpClient(); + var fetcher = CreateRealEndpointFetcher(httpClient); + + // Create context first (required before calling IsTelemetryEnabledAsync) + var context = cache.GetOrCreateContext(host); + + try + { + // Act + var result = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + + // Assert + // The result should be a boolean - either true or false is valid + // The important thing is that no exception was thrown + Assert.True(result == true || result == false); + OutputHelper?.WriteLine($"Feature flag result from real endpoint: {result}"); + + // Verify the cache was updated + Assert.NotNull(context.TelemetryEnabled); + Assert.NotNull(context.LastFetched); + } + finally + { + cache.ReleaseContext(host); + } + } + + #endregion + + #region FeatureFlagCache_CachesValue_DoesNotRefetchWithinTTL + + /// + /// Tests that FeatureFlagCache caches values and does not refetch within the TTL period. 
+ /// + [Fact] + public async Task FeatureFlagCache_CachesValue_DoesNotRefetchWithinTTL() + { + // Arrange + var cache = new FeatureFlagCache(TimeSpan.FromMinutes(15)); // Use default TTL + var host = "test-caching-host.databricks.com"; + var fetchCount = 0; + + var fetcher = async (CancellationToken ct) => + { + Interlocked.Increment(ref fetchCount); + await Task.CompletedTask; + return true; + }; + + // Create context first + var context = cache.GetOrCreateContext(host); + + try + { + // Act - First call should fetch + var result1 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + Assert.True(result1); + Assert.Equal(1, fetchCount); + + // Second call should use cached value + var result2 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + Assert.True(result2); + Assert.Equal(1, fetchCount); // Should NOT have fetched again + + // Third call should still use cached value + var result3 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + Assert.True(result3); + Assert.Equal(1, fetchCount); // Should NOT have fetched again + + // Assert + OutputHelper?.WriteLine($"Total fetch count: {fetchCount} (expected: 1)"); + Assert.Equal(1, fetchCount); + + // Verify cache state + Assert.True(context.TelemetryEnabled); + Assert.False(context.IsExpired); + } + finally + { + cache.ReleaseContext(host); + } + } + + /// + /// Tests that FeatureFlagCache refetches after cache expires. + /// + [Fact] + public async Task FeatureFlagCache_RefetchesAfterExpiry() + { + // Arrange - Use very short TTL for testing + var cache = new FeatureFlagCache(TimeSpan.FromMilliseconds(50)); + var host = "test-expiry-host.databricks.com"; + var fetchCount = 0; + + var fetcher = async (CancellationToken ct) => + { + Interlocked.Increment(ref fetchCount); + await Task.CompletedTask; + return true; + }; + + // Create context first + var context = cache.GetOrCreateContext(host); + + try + { + // Act - First call should fetch + var result1 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + Assert.True(result1); + Assert.Equal(1, fetchCount); + + // Wait for cache to expire + await Task.Delay(100); + Assert.True(context.IsExpired); + + // Second call should refetch because cache expired + var result2 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + Assert.True(result2); + Assert.Equal(2, fetchCount); // Should have fetched again + + // Assert + OutputHelper?.WriteLine($"Total fetch count after expiry: {fetchCount} (expected: 2)"); + } + finally + { + cache.ReleaseContext(host); + } + } + + #endregion + + #region FeatureFlagCache_InvalidHost_ReturnsDefaultFalse + + /// + /// Tests that FeatureFlagCache returns false for invalid hosts. 
+ /// + [Fact] + public async Task FeatureFlagCache_InvalidHost_ReturnsDefaultFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + var invalidHost = "invalid-host-that-does-not-exist-12345.databricks.com"; + var fetchCount = 0; + + // Fetcher that throws to simulate network error + Func> fetcher = async (CancellationToken ct) => + { + Interlocked.Increment(ref fetchCount); + await Task.CompletedTask; + throw new HttpRequestException("Host not found"); + }; + + // Create context first + var context = cache.GetOrCreateContext(invalidHost); + + try + { + // Act + var result = await cache.IsTelemetryEnabledAsync(invalidHost, fetcher, CancellationToken.None); + + // Assert - Should return false on error (safe default) + Assert.False(result); + Assert.Equal(1, fetchCount); // Should have attempted to fetch + OutputHelper?.WriteLine($"Invalid host returned: {result} (expected: false)"); + } + finally + { + cache.ReleaseContext(invalidHost); + } + } + + /// + /// Tests that FeatureFlagCache returns false for null host. + /// + [Fact] + public async Task FeatureFlagCache_NullHost_ReturnsDefaultFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + var fetchCalled = false; + + var fetcher = async (CancellationToken ct) => + { + fetchCalled = true; + await Task.CompletedTask; + return true; + }; + + // Act + var result = await cache.IsTelemetryEnabledAsync(null!, fetcher, CancellationToken.None); + + // Assert + Assert.False(result); + Assert.False(fetchCalled); // Should not have attempted to fetch for null host + } + + /// + /// Tests that FeatureFlagCache returns false for empty host. + /// + [Fact] + public async Task FeatureFlagCache_EmptyHost_ReturnsDefaultFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + var fetchCalled = false; + + var fetcher = async (CancellationToken ct) => + { + fetchCalled = true; + await Task.CompletedTask; + return true; + }; + + // Act + var result = await cache.IsTelemetryEnabledAsync("", fetcher, CancellationToken.None); + + // Assert + Assert.False(result); + Assert.False(fetchCalled); // Should not have attempted to fetch for empty host + } + + /// + /// Tests that FeatureFlagCache returns false for unknown host (no context created). + /// + [Fact] + public async Task FeatureFlagCache_UnknownHost_ReturnsDefaultFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + var unknownHost = "unknown-host.databricks.com"; + var fetchCalled = false; + + var fetcher = async (CancellationToken ct) => + { + fetchCalled = true; + await Task.CompletedTask; + return true; + }; + + // Act - Note: No context created for this host + var result = await cache.IsTelemetryEnabledAsync(unknownHost, fetcher, CancellationToken.None); + + // Assert + Assert.False(result); + Assert.False(fetchCalled); // Should not have attempted to fetch for unknown host + } + + #endregion + + #region FeatureFlagCache_RefCountingWorks_CleanupAfterRelease + + /// + /// Tests that FeatureFlagCache reference counting works correctly and + /// contexts are cleaned up after all references are released. 
+ /// + [Fact] + public void FeatureFlagCache_RefCountingWorks_CleanupAfterRelease() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-refcount-host.databricks.com"; + + // Act - Create multiple references + var context1 = cache.GetOrCreateContext(host); + Assert.Equal(1, context1.RefCount); + Assert.True(cache.HasContext(host)); + + var context2 = cache.GetOrCreateContext(host); + Assert.Equal(2, context2.RefCount); + Assert.Same(context1, context2); // Should be same instance + + var context3 = cache.GetOrCreateContext(host); + Assert.Equal(3, context3.RefCount); + Assert.Same(context1, context3); // Should be same instance + + OutputHelper?.WriteLine($"After creating 3 references: RefCount = {context1.RefCount}"); + + // Release references one by one + cache.ReleaseContext(host); + Assert.Equal(2, context1.RefCount); + Assert.True(cache.HasContext(host)); // Still has references + + cache.ReleaseContext(host); + Assert.Equal(1, context1.RefCount); + Assert.True(cache.HasContext(host)); // Still has references + + cache.ReleaseContext(host); + // After last release, context should be removed + + // Assert + Assert.False(cache.HasContext(host)); + Assert.Equal(0, cache.CachedHostCount); + OutputHelper?.WriteLine($"After releasing all references: Context removed"); + } + + /// + /// Tests that releasing context for unknown host doesn't throw. + /// + [Fact] + public void FeatureFlagCache_ReleaseUnknownHost_DoesNotThrow() + { + // Arrange + var cache = new FeatureFlagCache(); + var unknownHost = "never-created-host.databricks.com"; + + // Act & Assert - Should not throw + cache.ReleaseContext(unknownHost); + cache.ReleaseContext(null!); + cache.ReleaseContext(""); + cache.ReleaseContext(" "); + + // All should complete without exception + Assert.Equal(0, cache.CachedHostCount); + } + + /// + /// Tests that multiple hosts can have independent reference counts. + /// + [Fact] + public void FeatureFlagCache_MultipleHosts_IndependentRefCounts() + { + // Arrange + var cache = new FeatureFlagCache(); + var host1 = "host1.databricks.com"; + var host2 = "host2.databricks.com"; + var host3 = "host3.databricks.com"; + + // Act - Create contexts for multiple hosts + var context1a = cache.GetOrCreateContext(host1); + var context1b = cache.GetOrCreateContext(host1); + var context2 = cache.GetOrCreateContext(host2); + var context3a = cache.GetOrCreateContext(host3); + var context3b = cache.GetOrCreateContext(host3); + var context3c = cache.GetOrCreateContext(host3); + + // Assert initial state + Assert.Equal(3, cache.CachedHostCount); + Assert.Equal(2, context1a.RefCount); + Assert.Equal(1, context2.RefCount); + Assert.Equal(3, context3a.RefCount); + + // Release host2 completely + cache.ReleaseContext(host2); + Assert.Equal(2, cache.CachedHostCount); + Assert.False(cache.HasContext(host2)); + + // Host1 and Host3 should still exist + Assert.True(cache.HasContext(host1)); + Assert.True(cache.HasContext(host3)); + + // Clean up remaining + cache.ReleaseContext(host1); + cache.ReleaseContext(host1); + cache.ReleaseContext(host3); + cache.ReleaseContext(host3); + cache.ReleaseContext(host3); + + Assert.Equal(0, cache.CachedHostCount); + } + + /// + /// Tests concurrent reference counting is thread-safe. 
+ /// + [Fact] + public async Task FeatureFlagCache_ConcurrentRefCounting_ThreadSafe() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "concurrent-host.databricks.com"; + var incrementCount = 100; + var tasks = new Task[incrementCount]; + + // Act - Concurrently create references + for (int i = 0; i < incrementCount; i++) + { + tasks[i] = Task.Run(() => cache.GetOrCreateContext(host)); + } + await Task.WhenAll(tasks); + + // Assert + Assert.True(cache.TryGetContext(host, out var context)); + Assert.Equal(incrementCount, context!.RefCount); + + // Concurrently release references + var releaseTasks = new Task[incrementCount]; + for (int i = 0; i < incrementCount; i++) + { + releaseTasks[i] = Task.Run(() => cache.ReleaseContext(host)); + } + await Task.WhenAll(releaseTasks); + + // After all releases, context should be removed + Assert.False(cache.HasContext(host)); + } + + #endregion + + #region Additional E2E Tests + + /// + /// Tests that cached false value is correctly returned. + /// + [Fact] + public async Task FeatureFlagCache_CachesFalseValue_ReturnsCorrectly() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-false-value-host.databricks.com"; + var fetchCount = 0; + + var fetcher = async (CancellationToken ct) => + { + Interlocked.Increment(ref fetchCount); + await Task.CompletedTask; + return false; // Return false + }; + + // Create context first + var context = cache.GetOrCreateContext(host); + + try + { + // Act + var result1 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + var result2 = await cache.IsTelemetryEnabledAsync(host, fetcher, CancellationToken.None); + + // Assert + Assert.False(result1); + Assert.False(result2); + Assert.Equal(1, fetchCount); // Should only fetch once + Assert.False(context.TelemetryEnabled); + } + finally + { + cache.ReleaseContext(host); + } + } + + /// + /// Tests that cancellation is properly propagated during fetch. + /// + [Fact] + public async Task FeatureFlagCache_Cancellation_PropagatesCorrectly() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-cancellation-host.databricks.com"; + var cts = new CancellationTokenSource(); + + var fetcher = async (CancellationToken ct) => + { + await Task.Delay(10000, ct); // Long delay that should be cancelled + return true; + }; + + // Create context first + var context = cache.GetOrCreateContext(host); + + try + { + // Act + cts.CancelAfter(50); // Cancel after 50ms + + // Assert - TaskCanceledException inherits from OperationCanceledException + var ex = await Assert.ThrowsAnyAsync( + () => cache.IsTelemetryEnabledAsync(host, fetcher, cts.Token)); + Assert.True(ex is OperationCanceledException); + } + finally + { + cache.ReleaseContext(host); + } + } + + #endregion + } +} diff --git a/csharp/test/Unit/FeatureFlagCacheTests.cs b/csharp/test/Unit/FeatureFlagCacheTests.cs new file mode 100644 index 00000000..e61e7d91 --- /dev/null +++ b/csharp/test/Unit/FeatureFlagCacheTests.cs @@ -0,0 +1,773 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks; +using Moq; +using Moq.Protected; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit +{ + /// + /// Tests for FeatureFlagCache and FeatureFlagContext classes. + /// + public class FeatureFlagCacheTests + { + private const string TestHost = "test-host.databricks.com"; + private const string DriverVersion = "1.0.0"; + + #region FeatureFlagContext Tests - Basic Functionality + + [Fact] + public void FeatureFlagContext_GetFlagValue_ReturnsValue() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + + // Act & Assert + Assert.Equal("value1", context.GetFlagValue("flag1")); + Assert.Equal("value2", context.GetFlagValue("flag2")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_NotFound_ReturnsNull() + { + // Arrange + var context = CreateTestContext(); + + // Act & Assert + Assert.Null(context.GetFlagValue("nonexistent")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_NullOrEmpty_ReturnsNull() + { + // Arrange + var context = CreateTestContext(); + + // Act & Assert + Assert.Null(context.GetFlagValue(null!)); + Assert.Null(context.GetFlagValue("")); + Assert.Null(context.GetFlagValue(" ")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_CaseInsensitive() + { + // Arrange + var flags = new Dictionary + { + ["MyFlag"] = "value" + }; + var context = CreateTestContext(flags); + + // Act & Assert + Assert.Equal("value", context.GetFlagValue("myflag")); + Assert.Equal("value", context.GetFlagValue("MYFLAG")); + Assert.Equal("value", context.GetFlagValue("MyFlag")); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_ReturnsAllFlags() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2", + ["flag3"] = "value3" + }; + var context = CreateTestContext(flags); + + // Act + var allFlags = context.GetAllFlags(); + + // Assert + Assert.Equal(3, allFlags.Count); + Assert.Equal("value1", allFlags["flag1"]); + Assert.Equal("value2", allFlags["flag2"]); + Assert.Equal("value3", allFlags["flag3"]); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_ReturnsSnapshot() + { + // Arrange + var context = CreateTestContext(); + context.SetFlag("flag1", "value1"); + + // Act + var snapshot = context.GetAllFlags(); + context.SetFlag("flag2", "value2"); + + // Assert - snapshot should not include new flag + Assert.Single(snapshot); + Assert.Equal("value1", snapshot["flag1"]); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_Empty_ReturnsEmptyDictionary() + { + // Arrange + var context = CreateTestContext(); + + // Act + var allFlags = context.GetAllFlags(); + + // Assert + Assert.Empty(allFlags); + } + + #endregion + + #region FeatureFlagContext Tests - TTL + + [Fact] + public void FeatureFlagContext_DefaultTtl_Is15Minutes() + { + // Arrange + var context = CreateTestContext(); + + // Assert + Assert.Equal(TimeSpan.FromMinutes(15), context.Ttl); + Assert.Equal(TimeSpan.FromMinutes(15), context.RefreshInterval); // Alias + } + + [Fact] + public void FeatureFlagContext_CustomTtl() + { + // Arrange + var customTtl = TimeSpan.FromMinutes(5); + var context = CreateTestContext(null, customTtl); + + // 
Assert + Assert.Equal(customTtl, context.Ttl); + Assert.Equal(customTtl, context.RefreshInterval); + } + + #endregion + + #region FeatureFlagContext Tests - Dispose + + [Fact] + public void FeatureFlagContext_Dispose_CanBeCalledMultipleTimes() + { + // Arrange + var context = CreateTestContext(); + + // Act - should not throw + context.Dispose(); + context.Dispose(); + context.Dispose(); + } + + #endregion + + #region FeatureFlagContext Tests - Internal Methods + + [Fact] + public void FeatureFlagContext_SetFlag_AddsOrUpdatesFlag() + { + // Arrange + var context = CreateTestContext(); + + // Act + context.SetFlag("flag1", "value1"); + context.SetFlag("flag2", "value2"); + context.SetFlag("flag1", "updated"); + + // Assert + Assert.Equal("updated", context.GetFlagValue("flag1")); + Assert.Equal("value2", context.GetFlagValue("flag2")); + } + + [Fact] + public void FeatureFlagContext_ClearFlags_RemovesAllFlags() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + + // Act + context.ClearFlags(); + + // Assert + Assert.Empty(context.GetAllFlags()); + } + + #endregion + + #region FeatureFlagCache Singleton Tests + + [Fact] + public void FeatureFlagCache_GetInstance_ReturnsSingleton() + { + // Act + var instance1 = FeatureFlagCache.GetInstance(); + var instance2 = FeatureFlagCache.GetInstance(); + + // Assert + Assert.Same(instance1, instance2); + } + + #endregion + + #region FeatureFlagCache_GetOrCreateContext Tests + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NewHost_CreatesContext() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context = cache.GetOrCreateContext("test-host-1.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.NotNull(context); + Assert.True(cache.HasContext("test-host-1.databricks.com")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ExistingHost_ReturnsSameContext() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-host-2.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext(host, httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Assert + Assert.Same(context1, context2); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_MultipleHosts_CreatesMultipleContexts() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext("host1.databricks.com", httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext("host2.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.NotSame(context1, context2); + Assert.Equal(2, cache.CachedHostCount); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NullHost_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext(null!, httpClient, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_EmptyHost_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = 
CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext("", httpClient, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NullHttpClient_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext(TestHost, null!, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_CaseInsensitive() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "Test-Host.Databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext(host.ToLower(), httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext(host.ToUpper(), httpClient, DriverVersion); + + // Assert + Assert.Same(context1, context2); + Assert.Equal(1, cache.CachedHostCount); + + // Cleanup + cache.Clear(); + } + + #endregion + + #region FeatureFlagCache_RemoveContext Tests + + [Fact] + public void FeatureFlagCache_RemoveContext_RemovesContext() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-host-3.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Act + cache.RemoveContext(host); + + // Assert + Assert.False(cache.HasContext(host)); + Assert.Equal(0, cache.CachedHostCount); + } + + [Fact] + public void FeatureFlagCache_RemoveContext_UnknownHost_DoesNothing() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act - should not throw + cache.RemoveContext("unknown-host.databricks.com"); + + // Assert + Assert.Equal(0, cache.CachedHostCount); + } + + [Fact] + public void FeatureFlagCache_RemoveContext_NullHost_DoesNothing() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act - should not throw + cache.RemoveContext(null!); + } + + #endregion + + #region FeatureFlagCache with API Response Tests + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ParsesFlags() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "flag1", Value = "value1" }, + new FeatureFlagEntry { Name = "flag2", Value = "true" } + }, + TtlSeconds = 300 + }; + var httpClient = CreateMockHttpClient(response); + + // Act + var context = cache.GetOrCreateContext("test-api.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.Equal("value1", context.GetFlagValue("flag1")); + Assert.Equal("true", context.GetFlagValue("flag2")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_UpdatesTtl() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List(), + TtlSeconds = 300 // 5 minutes + }; + var httpClient = CreateMockHttpClient(response); + + // Act + var context = cache.GetOrCreateContext("test-ttl.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.Equal(TimeSpan.FromSeconds(300), context.Ttl); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ApiError_DoesNotThrow() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(HttpStatusCode.InternalServerError); + + // Act - should not throw + var context = cache.GetOrCreateContext("test-error.databricks.com", httpClient, DriverVersion); + + // Assert + 
Assert.NotNull(context); + Assert.Empty(context.GetAllFlags()); + + // Cleanup + cache.Clear(); + } + + #endregion + + #region FeatureFlagCache Helper Method Tests + + [Fact] + public void FeatureFlagCache_TryGetContext_ExistingContext_ReturnsTrue() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "try-get-host.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + var expectedContext = cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Act + var result = cache.TryGetContext(host, out var context); + + // Assert + Assert.True(result); + Assert.Same(expectedContext, context); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_TryGetContext_UnknownHost_ReturnsFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act + var result = cache.TryGetContext("unknown.databricks.com", out var context); + + // Assert + Assert.False(result); + Assert.Null(context); + } + + [Fact] + public void FeatureFlagCache_Clear_RemovesAllContexts() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + cache.GetOrCreateContext("host1.databricks.com", httpClient, DriverVersion); + cache.GetOrCreateContext("host2.databricks.com", httpClient, DriverVersion); + cache.GetOrCreateContext("host3.databricks.com", httpClient, DriverVersion); + Assert.Equal(3, cache.CachedHostCount); + + // Act + cache.Clear(); + + // Assert + Assert.Equal(0, cache.CachedHostCount); + } + + #endregion + + #region Async Initial Fetch Tests + + [Fact] + public async Task FeatureFlagCache_GetOrCreateContextAsync_AwaitsInitialFetch_FlagsAvailableImmediately() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "async_flag1", Value = "async_value1" }, + new FeatureFlagEntry { Name = "async_flag2", Value = "async_value2" } + }, + TtlSeconds = 300 + }; + var httpClient = CreateMockHttpClient(response); + + // Act - Use async method explicitly + var context = await cache.GetOrCreateContextAsync("test-async.databricks.com", httpClient, DriverVersion); + + // Assert - Flags should be immediately available after await completes + // This verifies that GetOrCreateContextAsync waits for the initial fetch + Assert.Equal("async_value1", context.GetFlagValue("async_flag1")); + Assert.Equal("async_value2", context.GetFlagValue("async_flag2")); + Assert.Equal(2, context.GetAllFlags().Count); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagCache_GetOrCreateContextAsync_WithDelayedResponse_StillAwaitsInitialFetch() + { + // Arrange - Create a mock that simulates network delay + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "delayed_flag", Value = "delayed_value" } + }, + TtlSeconds = 300 + }; + var httpClient = CreateDelayedMockHttpClient(response, delayMs: 100); + + // Act - Measure time to verify we actually waited + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var context = await cache.GetOrCreateContextAsync("test-delayed.databricks.com", httpClient, DriverVersion); + stopwatch.Stop(); + + // Assert - Should have waited for the delayed response + Assert.True(stopwatch.ElapsedMilliseconds >= 50, "Should have waited for the delayed fetch"); + + // Flags should be available immediately after await + Assert.Equal("delayed_value", 
context.GetFlagValue("delayed_flag")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagContext_CreateAsync_AwaitsInitialFetch_FlagsPopulated() + { + // Arrange + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "create_async_flag", Value = "create_async_value" } + }, + TtlSeconds = 600 + }; + var httpClient = CreateMockHttpClient(response); + + // Act - Call CreateAsync directly + var context = await FeatureFlagContext.CreateAsync( + "test-create-async.databricks.com", + httpClient, + DriverVersion); + + // Assert - Flags should be populated after CreateAsync completes + Assert.Equal("create_async_value", context.GetFlagValue("create_async_flag")); + Assert.Equal(TimeSpan.FromSeconds(600), context.Ttl); + + // Cleanup + context.Dispose(); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async Task FeatureFlagCache_ConcurrentGetOrCreateContext_ThreadSafe() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "concurrent-host.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + var tasks = new Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + tasks[i] = Task.Run(() => cache.GetOrCreateContext(host, httpClient, DriverVersion)); + } + + var contexts = await Task.WhenAll(tasks); + + // Assert - All should be the same context + var firstContext = contexts[0]; + Assert.All(contexts, ctx => Assert.Same(firstContext, ctx)); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagContext_ConcurrentFlagAccess_ThreadSafe() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + var tasks = new Task[100]; + + // Act - Concurrent reads and writes + for (int i = 0; i < 100; i++) + { + var index = i; + tasks[i] = Task.Run(() => + { + // Read + var value = context.GetFlagValue("flag1"); + var all = context.GetAllFlags(); + var flag2Value = context.GetFlagValue("flag2"); + + // Write + context.SetFlag($"new_flag_{index}", $"value_{index}"); + }); + } + + await Task.WhenAll(tasks); + + // Assert - No exceptions thrown, all flags accessible + Assert.Equal("value1", context.GetFlagValue("flag1")); + var allFlags = context.GetAllFlags(); + Assert.True(allFlags.Count >= 2); // At least original flags + } + + #endregion + + #region Helper Methods + + /// + /// Creates a FeatureFlagContext for unit testing with pre-populated flags. + /// Does not make API calls or start background refresh. + /// + private static FeatureFlagContext CreateTestContext( + IReadOnlyDictionary? initialFlags = null, + TimeSpan? 
ttl = null) + { + var context = new FeatureFlagContext( + host: "test-host", + httpClient: null, + driverVersion: DriverVersion, + endpointFormat: null); + + if (ttl.HasValue) + { + context.Ttl = ttl.Value; + } + + if (initialFlags != null) + { + foreach (var kvp in initialFlags) + { + context.SetFlag(kvp.Key, kvp.Value); + } + } + + return context; + } + + private static HttpClient CreateMockHttpClient(FeatureFlagsResponse response) + { + var json = JsonSerializer.Serialize(response); + return CreateMockHttpClient(HttpStatusCode.OK, json); + } + + private static HttpClient CreateMockHttpClient(HttpStatusCode statusCode, string content = "") + { + var mockHandler = new Mock(); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = statusCode, + Content = new StringContent(content) + }); + + return new HttpClient(mockHandler.Object) + { + BaseAddress = new Uri("https://test.databricks.com") + }; + } + + private static HttpClient CreateMockHttpClient(HttpStatusCode statusCode) + { + return CreateMockHttpClient(statusCode, ""); + } + + private static HttpClient CreateDelayedMockHttpClient(FeatureFlagsResponse response, int delayMs) + { + var json = JsonSerializer.Serialize(response); + var mockHandler = new Mock(); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (HttpRequestMessage request, CancellationToken token) => + { + await Task.Delay(delayMs, token); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(json) + }; + }); + + return new HttpClient(mockHandler.Object) + { + BaseAddress = new Uri("https://test.databricks.com") + }; + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/CircuitBreakerManagerTests.cs b/csharp/test/Unit/Telemetry/CircuitBreakerManagerTests.cs new file mode 100644 index 00000000..c64914f4 --- /dev/null +++ b/csharp/test/Unit/Telemetry/CircuitBreakerManagerTests.cs @@ -0,0 +1,642 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for CircuitBreakerManager class. 
+ /// + public class CircuitBreakerManagerTests + { + #region Singleton Tests + + [Fact] + public void CircuitBreakerManager_GetInstance_ReturnsSingleton() + { + // Act + var instance1 = CircuitBreakerManager.GetInstance(); + var instance2 = CircuitBreakerManager.GetInstance(); + + // Assert + Assert.NotNull(instance1); + Assert.Same(instance1, instance2); + } + + #endregion + + #region GetCircuitBreaker - New Host Tests + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_NewHost_CreatesBreaker() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "test-host-new.databricks.com"; + + // Act + var circuitBreaker = manager.GetCircuitBreaker(host); + + // Assert + Assert.NotNull(circuitBreaker); + Assert.Equal(1, manager.CircuitBreakerCount); + Assert.True(manager.HasCircuitBreaker(host)); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_NewHost_UsesDefaultConfig() + { + // Arrange + var defaultConfig = new CircuitBreakerConfig + { + FailureThreshold = 10, + Timeout = TimeSpan.FromMinutes(2), + SuccessThreshold = 3 + }; + var manager = new CircuitBreakerManager(defaultConfig); + var host = "test-host-config.databricks.com"; + + // Act + var circuitBreaker = manager.GetCircuitBreaker(host); + + // Assert + Assert.NotNull(circuitBreaker); + Assert.Equal(10, circuitBreaker.Config.FailureThreshold); + Assert.Equal(TimeSpan.FromMinutes(2), circuitBreaker.Config.Timeout); + Assert.Equal(3, circuitBreaker.Config.SuccessThreshold); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_NullHost_ThrowsException() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act & Assert + Assert.Throws(() => manager.GetCircuitBreaker(null!)); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_EmptyHost_ThrowsException() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act & Assert + Assert.Throws(() => manager.GetCircuitBreaker(string.Empty)); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_WhitespaceHost_ThrowsException() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act & Assert + Assert.Throws(() => manager.GetCircuitBreaker(" ")); + } + + #endregion + + #region GetCircuitBreaker - Same Host Tests + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_SameHost_ReturnsSameBreaker() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "test-host-same.databricks.com"; + + // Act + var circuitBreaker1 = manager.GetCircuitBreaker(host); + var circuitBreaker2 = manager.GetCircuitBreaker(host); + + // Assert + Assert.Same(circuitBreaker1, circuitBreaker2); + Assert.Equal(1, manager.CircuitBreakerCount); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_SameHostDifferentCase_ReturnsSameBreaker() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var circuitBreaker1 = manager.GetCircuitBreaker("TEST-HOST.databricks.com"); + var circuitBreaker2 = manager.GetCircuitBreaker("test-host.databricks.com"); + var circuitBreaker3 = manager.GetCircuitBreaker("Test-Host.Databricks.Com"); + + // Assert + Assert.Same(circuitBreaker1, circuitBreaker2); + Assert.Same(circuitBreaker2, circuitBreaker3); + Assert.Equal(1, manager.CircuitBreakerCount); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_MultipleCallsSameHost_ReturnsSameInstance() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "test-host-multiple.databricks.com"; + var instances = new 
List(); + + // Act + for (int i = 0; i < 100; i++) + { + instances.Add(manager.GetCircuitBreaker(host)); + } + + // Assert + Assert.All(instances, cb => Assert.Same(instances[0], cb)); + Assert.Equal(1, manager.CircuitBreakerCount); + } + + #endregion + + #region GetCircuitBreaker - Different Hosts Tests + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_DifferentHosts_CreatesSeparateBreakers() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host1 = "host1.databricks.com"; + var host2 = "host2.databricks.com"; + + // Act + var circuitBreaker1 = manager.GetCircuitBreaker(host1); + var circuitBreaker2 = manager.GetCircuitBreaker(host2); + + // Assert + Assert.NotNull(circuitBreaker1); + Assert.NotNull(circuitBreaker2); + Assert.NotSame(circuitBreaker1, circuitBreaker2); + Assert.Equal(2, manager.CircuitBreakerCount); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreaker_ManyHosts_CreatesAllBreakers() + { + // Arrange + var manager = new CircuitBreakerManager(); + var hosts = new string[] + { + "host1.databricks.com", + "host2.databricks.com", + "host3.databricks.com", + "host4.databricks.com", + "host5.databricks.com" + }; + + // Act + var circuitBreakers = new Dictionary(); + foreach (var host in hosts) + { + circuitBreakers[host] = manager.GetCircuitBreaker(host); + } + + // Assert + Assert.Equal(5, manager.CircuitBreakerCount); + foreach (var host in hosts) + { + Assert.True(manager.HasCircuitBreaker(host)); + } + } + + #endregion + + #region GetCircuitBreaker with Config Tests + + [Fact] + public void CircuitBreakerManager_GetCircuitBreakerWithConfig_NewHost_UsesProvidedConfig() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "test-host-custom.databricks.com"; + var customConfig = new CircuitBreakerConfig + { + FailureThreshold = 15, + Timeout = TimeSpan.FromMinutes(5), + SuccessThreshold = 4 + }; + + // Act + var circuitBreaker = manager.GetCircuitBreaker(host, customConfig); + + // Assert + Assert.NotNull(circuitBreaker); + Assert.Equal(15, circuitBreaker.Config.FailureThreshold); + Assert.Equal(TimeSpan.FromMinutes(5), circuitBreaker.Config.Timeout); + Assert.Equal(4, circuitBreaker.Config.SuccessThreshold); + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreakerWithConfig_ExistingHost_ReturnsExistingBreaker() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "test-host-existing.databricks.com"; + var originalConfig = new CircuitBreakerConfig + { + FailureThreshold = 10 + }; + var newConfig = new CircuitBreakerConfig + { + FailureThreshold = 20 + }; + + // Act + var circuitBreaker1 = manager.GetCircuitBreaker(host, originalConfig); + var circuitBreaker2 = manager.GetCircuitBreaker(host, newConfig); + + // Assert + Assert.Same(circuitBreaker1, circuitBreaker2); + Assert.Equal(10, circuitBreaker2.Config.FailureThreshold); // Original config retained + } + + [Fact] + public void CircuitBreakerManager_GetCircuitBreakerWithConfig_NullConfig_ThrowsException() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act & Assert + Assert.Throws(() => + manager.GetCircuitBreaker("valid-host.databricks.com", null!)); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async Task CircuitBreakerManager_ConcurrentGetCircuitBreaker_SameHost_ThreadSafe() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "concurrent-host.databricks.com"; + var circuitBreakers = new CircuitBreaker[100]; + var tasks = new Task[100]; + + // Act + for 
(int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = Task.Run(() => + { + circuitBreakers[index] = manager.GetCircuitBreaker(host); + }); + } + + await Task.WhenAll(tasks); + + // Assert + Assert.Equal(1, manager.CircuitBreakerCount); + Assert.All(circuitBreakers, cb => Assert.Same(circuitBreakers[0], cb)); + } + + [Fact] + public async Task CircuitBreakerManager_ConcurrentGetCircuitBreaker_DifferentHosts_ThreadSafe() + { + // Arrange + var manager = new CircuitBreakerManager(); + var hostCount = 50; + var tasks = new Task[hostCount]; + + // Act + for (int i = 0; i < hostCount; i++) + { + int index = i; + tasks[i] = Task.Run(() => + { + manager.GetCircuitBreaker($"host{index}.databricks.com"); + }); + } + + await Task.WhenAll(tasks); + + // Assert + Assert.Equal(hostCount, manager.CircuitBreakerCount); + } + + #endregion + + #region Per-Host Isolation Tests + + [Fact] + public async Task CircuitBreakerManager_PerHostIsolation_FailureInOneHostDoesNotAffectOther() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 2 }; + var manager = new CircuitBreakerManager(config); + var host1 = "host1-isolation.databricks.com"; + var host2 = "host2-isolation.databricks.com"; + + var cb1 = manager.GetCircuitBreaker(host1); + var cb2 = manager.GetCircuitBreaker(host2); + + // Act - Cause failures on host1 to open its circuit + for (int i = 0; i < 2; i++) + { + try + { + await cb1.ExecuteAsync(() => throw new Exception("Failure")); + } + catch { } + } + + // Assert - Host1 circuit is open, Host2 circuit is still closed + Assert.Equal(CircuitBreakerState.Open, cb1.State); + Assert.Equal(CircuitBreakerState.Closed, cb2.State); + + // Host2 should still execute successfully + var executed = false; + await cb2.ExecuteAsync(async () => + { + executed = true; + await Task.CompletedTask; + }); + + Assert.True(executed); + } + + [Fact] + public async Task CircuitBreakerManager_PerHostIsolation_IndependentStateTransitions() + { + // Arrange + var config = new CircuitBreakerConfig + { + FailureThreshold = 1, + Timeout = TimeSpan.FromMilliseconds(100), + SuccessThreshold = 1 + }; + var manager = new CircuitBreakerManager(config); + var host1 = "host1-state.databricks.com"; + var host2 = "host2-state.databricks.com"; + var host3 = "host3-state.databricks.com"; + + var cb1 = manager.GetCircuitBreaker(host1); + var cb2 = manager.GetCircuitBreaker(host2); + var cb3 = manager.GetCircuitBreaker(host3); + + // Act - Put cb1 in Open state + try { await cb1.ExecuteAsync(() => throw new Exception("Failure")); } catch { } + Assert.Equal(CircuitBreakerState.Open, cb1.State); + + // Act - Put cb2 in HalfOpen state (Open then wait for timeout) + try { await cb2.ExecuteAsync(() => throw new Exception("Failure")); } catch { } + await Task.Delay(150); + // Transition to HalfOpen happens on next execute attempt + await cb2.ExecuteAsync(async () => await Task.CompletedTask); + + // cb3 stays Closed (no failures) + + // Assert - Each host has independent state + Assert.Equal(CircuitBreakerState.Open, cb1.State); + // cb2 is either HalfOpen or Closed depending on SuccessThreshold + Assert.True(cb2.State == CircuitBreakerState.HalfOpen || cb2.State == CircuitBreakerState.Closed); + Assert.Equal(CircuitBreakerState.Closed, cb3.State); + } + + #endregion + + #region HasCircuitBreaker Tests + + [Fact] + public void CircuitBreakerManager_HasCircuitBreaker_ExistingHost_ReturnsTrue() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "existing-host.databricks.com"; + 
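+ // Register a breaker for this host first so the HasCircuitBreaker check below has something to find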
manager.GetCircuitBreaker(host); + + // Act + var exists = manager.HasCircuitBreaker(host); + + // Assert + Assert.True(exists); + } + + [Fact] + public void CircuitBreakerManager_HasCircuitBreaker_NonExistingHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var exists = manager.HasCircuitBreaker("non-existing.databricks.com"); + + // Assert + Assert.False(exists); + } + + [Fact] + public void CircuitBreakerManager_HasCircuitBreaker_NullHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var exists = manager.HasCircuitBreaker(null!); + + // Assert + Assert.False(exists); + } + + [Fact] + public void CircuitBreakerManager_HasCircuitBreaker_EmptyHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var exists = manager.HasCircuitBreaker(string.Empty); + + // Assert + Assert.False(exists); + } + + #endregion + + #region TryGetCircuitBreaker Tests + + [Fact] + public void CircuitBreakerManager_TryGetCircuitBreaker_ExistingHost_ReturnsTrue() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "try-get-host.databricks.com"; + var originalCircuitBreaker = manager.GetCircuitBreaker(host); + + // Act + var found = manager.TryGetCircuitBreaker(host, out var circuitBreaker); + + // Assert + Assert.True(found); + Assert.Same(originalCircuitBreaker, circuitBreaker); + } + + [Fact] + public void CircuitBreakerManager_TryGetCircuitBreaker_NonExistingHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var found = manager.TryGetCircuitBreaker("non-existing.databricks.com", out var circuitBreaker); + + // Assert + Assert.False(found); + Assert.Null(circuitBreaker); + } + + [Fact] + public void CircuitBreakerManager_TryGetCircuitBreaker_NullHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var found = manager.TryGetCircuitBreaker(null!, out var circuitBreaker); + + // Assert + Assert.False(found); + Assert.Null(circuitBreaker); + } + + #endregion + + #region RemoveCircuitBreaker Tests + + [Fact] + public void CircuitBreakerManager_RemoveCircuitBreaker_ExistingHost_ReturnsTrue() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "remove-host.databricks.com"; + manager.GetCircuitBreaker(host); + + // Act + var removed = manager.RemoveCircuitBreaker(host); + + // Assert + Assert.True(removed); + Assert.False(manager.HasCircuitBreaker(host)); + Assert.Equal(0, manager.CircuitBreakerCount); + } + + [Fact] + public void CircuitBreakerManager_RemoveCircuitBreaker_NonExistingHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var removed = manager.RemoveCircuitBreaker("non-existing.databricks.com"); + + // Assert + Assert.False(removed); + } + + [Fact] + public void CircuitBreakerManager_RemoveCircuitBreaker_NullHost_ReturnsFalse() + { + // Arrange + var manager = new CircuitBreakerManager(); + + // Act + var removed = manager.RemoveCircuitBreaker(null!); + + // Assert + Assert.False(removed); + } + + [Fact] + public void CircuitBreakerManager_RemoveCircuitBreaker_AfterRemoval_GetCreatesNewBreaker() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "remove-recreate-host.databricks.com"; + var originalCircuitBreaker = manager.GetCircuitBreaker(host); + manager.RemoveCircuitBreaker(host); + + // Act + var newCircuitBreaker = manager.GetCircuitBreaker(host); + + // Assert + 
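+ // After removal, GetCircuitBreaker must build a fresh instance rather than return the removed one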
Assert.NotSame(originalCircuitBreaker, newCircuitBreaker); + } + + #endregion + + #region Clear Tests + + [Fact] + public void CircuitBreakerManager_Clear_RemovesAllBreakers() + { + // Arrange + var manager = new CircuitBreakerManager(); + manager.GetCircuitBreaker("host1.databricks.com"); + manager.GetCircuitBreaker("host2.databricks.com"); + manager.GetCircuitBreaker("host3.databricks.com"); + + // Act + manager.Clear(); + + // Assert + Assert.Equal(0, manager.CircuitBreakerCount); + Assert.False(manager.HasCircuitBreaker("host1.databricks.com")); + Assert.False(manager.HasCircuitBreaker("host2.databricks.com")); + Assert.False(manager.HasCircuitBreaker("host3.databricks.com")); + } + + #endregion + + #region Constructor Tests + + [Fact] + public void CircuitBreakerManager_Constructor_NullConfig_ThrowsException() + { + // Act & Assert + Assert.Throws(() => new CircuitBreakerManager(null!)); + } + + [Fact] + public void CircuitBreakerManager_DefaultConstructor_UsesDefaultConfig() + { + // Arrange + var manager = new CircuitBreakerManager(); + var host = "default-config-host.databricks.com"; + + // Act + var circuitBreaker = manager.GetCircuitBreaker(host); + + // Assert - Default config values + Assert.Equal(5, circuitBreaker.Config.FailureThreshold); + Assert.Equal(TimeSpan.FromMinutes(1), circuitBreaker.Config.Timeout); + Assert.Equal(2, circuitBreaker.Config.SuccessThreshold); + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/CircuitBreakerTelemetryExporterTests.cs b/csharp/test/Unit/Telemetry/CircuitBreakerTelemetryExporterTests.cs new file mode 100644 index 00000000..239bacda --- /dev/null +++ b/csharp/test/Unit/Telemetry/CircuitBreakerTelemetryExporterTests.cs @@ -0,0 +1,535 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for CircuitBreakerTelemetryExporter class. 
+ /// + public class CircuitBreakerTelemetryExporterTests + { + private const string TestHost = "https://test-workspace.databricks.com"; + + #region Constructor Tests + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_NullHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => + new CircuitBreakerTelemetryExporter(null!, mockExporter)); + } + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_EmptyHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => + new CircuitBreakerTelemetryExporter(string.Empty, mockExporter)); + } + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_WhitespaceHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => + new CircuitBreakerTelemetryExporter(" ", mockExporter)); + } + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_NullInnerExporter_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new CircuitBreakerTelemetryExporter(TestHost, null!)); + } + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_NullCircuitBreakerManager_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => + new CircuitBreakerTelemetryExporter(TestHost, mockExporter, null!)); + } + + [Fact] + public void CircuitBreakerTelemetryExporter_Constructor_ValidParameters_SetsProperties() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter); + + // Assert + Assert.Equal(TestHost, exporter.Host); + Assert.Same(mockExporter, exporter.InnerExporter); + } + + #endregion + + #region Circuit Closed Tests + + [Fact] + public async Task CircuitBreakerTelemetryExporter_CircuitClosed_ExportsEvents() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var manager = new CircuitBreakerManager(); + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345, FrontendLogEventId = "event-1" }, + new TelemetryFrontendLog { WorkspaceId = 12345, FrontendLogEventId = "event-2" } + }; + + // Act + await exporter.ExportAsync(logs); + + // Assert + Assert.Equal(1, mockExporter.ExportCallCount); + Assert.Equal(logs, mockExporter.LastExportedLogs); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_CircuitClosed_MultipleExports_AllSucceed() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var manager = new CircuitBreakerManager(); + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + + // Assert + Assert.Equal(3, mockExporter.ExportCallCount); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_EmptyList_DoesNotCallInnerExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var manager = new CircuitBreakerManager(); + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + // Act + await exporter.ExportAsync(new List()); + + // Assert + Assert.Equal(0, 
mockExporter.ExportCallCount); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_NullList_DoesNotCallInnerExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var manager = new CircuitBreakerManager(); + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + // Act + await exporter.ExportAsync(null!); + + // Assert + Assert.Equal(0, mockExporter.ExportCallCount); + } + + #endregion + + #region Circuit Open Tests + + [Fact] + public async Task CircuitBreakerTelemetryExporter_CircuitOpen_DropsEvents() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 2 }; + var manager = new CircuitBreakerManager(config); + var failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, failingExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Cause failures to open the circuit + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + + // Verify circuit is open + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + + // Reset call count after opening circuit + failingExporter.ExportCallCount = 0; + + // Act - Try to export while circuit is open + await exporter.ExportAsync(logs); + + // Assert - Inner exporter should NOT be called (events dropped) + Assert.Equal(0, failingExporter.ExportCallCount); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_CircuitOpen_DoesNotThrow() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 1 }; + var manager = new CircuitBreakerManager(config); + var failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, failingExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Cause failure to open the circuit + await exporter.ExportAsync(logs); + + // Verify circuit is open + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + + // Act & Assert - Should not throw even though circuit is open + var exception = await Record.ExceptionAsync(() => exporter.ExportAsync(logs)); + Assert.Null(exception); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_CircuitOpen_MultipleExportAttempts_AllDropped() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 1, Timeout = TimeSpan.FromHours(1) }; + var manager = new CircuitBreakerManager(config); + var failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, failingExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Cause failure to open the circuit + await exporter.ExportAsync(logs); + + // Reset call count and stop throwing + failingExporter.ExportCallCount = 0; + failingExporter.ShouldThrow = false; + + // Act - Try multiple exports while circuit is open + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + + // Assert - All should be dropped (inner exporter not called) + Assert.Equal(0, failingExporter.ExportCallCount); + } + + #endregion + + #region Circuit Breaker Tracks Failure Tests + + [Fact] + public async Task 
CircuitBreakerTelemetryExporter_InnerExporterFails_CircuitBreakerTracksFailure() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 3 }; + var manager = new CircuitBreakerManager(config); + var failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, failingExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + + // Act - Cause failures + await exporter.ExportAsync(logs); + Assert.Equal(1, circuitBreaker.ConsecutiveFailures); + Assert.Equal(CircuitBreakerState.Closed, circuitBreaker.State); + + await exporter.ExportAsync(logs); + Assert.Equal(2, circuitBreaker.ConsecutiveFailures); + Assert.Equal(CircuitBreakerState.Closed, circuitBreaker.State); + + await exporter.ExportAsync(logs); + + // Assert - Circuit should now be open after 3 failures + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_InnerExporterFails_ExceptionSwallowed() + { + // Arrange + var manager = new CircuitBreakerManager(); + var failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, failingExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act & Assert - Should not throw even though inner exporter fails + var exception = await Record.ExceptionAsync(() => exporter.ExportAsync(logs)); + Assert.Null(exception); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_InnerExporterSucceeds_CircuitBreakerResetsFailures() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 5 }; + var manager = new CircuitBreakerManager(config); + var mockExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + + // Cause 2 failures + await exporter.ExportAsync(logs); + await exporter.ExportAsync(logs); + Assert.Equal(2, circuitBreaker.ConsecutiveFailures); + + // Now succeed + mockExporter.ShouldThrow = false; + await exporter.ExportAsync(logs); + + // Assert - Failures should be reset + Assert.Equal(0, circuitBreaker.ConsecutiveFailures); + Assert.Equal(CircuitBreakerState.Closed, circuitBreaker.State); + } + + #endregion + + #region Cancellation Tests + + [Fact] + public async Task CircuitBreakerTelemetryExporter_Cancelled_PropagatesCancellation() + { + // Arrange + var mockExporter = new MockTelemetryExporter { ShouldDelay = true }; + var manager = new CircuitBreakerManager(); + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - Cancellation should propagate + await Assert.ThrowsAnyAsync( + () => exporter.ExportAsync(logs, cts.Token)); + } + + #endregion + + #region Per-Host Isolation Tests + + [Fact] + public async Task CircuitBreakerTelemetryExporter_DifferentHosts_IndependentCircuitBreakers() + { + // Arrange + var config = new CircuitBreakerConfig { FailureThreshold = 1 }; + var manager = new CircuitBreakerManager(config); + + var 
failingExporter = new MockTelemetryExporter { ShouldThrow = true }; + var successExporter = new MockTelemetryExporter { ShouldThrow = false }; + + var host1 = "https://host1.databricks.com"; + var host2 = "https://host2.databricks.com"; + + var exporter1 = new CircuitBreakerTelemetryExporter(host1, failingExporter, manager); + var exporter2 = new CircuitBreakerTelemetryExporter(host2, successExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act - Cause failure on host1 to open its circuit + await exporter1.ExportAsync(logs); + + // Assert - Host1 circuit is open, Host2 circuit is still closed + var cb1 = manager.GetCircuitBreaker(host1); + var cb2 = manager.GetCircuitBreaker(host2); + + Assert.Equal(CircuitBreakerState.Open, cb1.State); + Assert.Equal(CircuitBreakerState.Closed, cb2.State); + + // Reset call count + successExporter.ExportCallCount = 0; + + // Act - Export on host2 should still work + await exporter2.ExportAsync(logs); + + // Assert + Assert.Equal(1, successExporter.ExportCallCount); + } + + #endregion + + #region HalfOpen State Tests + + [Fact] + public async Task CircuitBreakerTelemetryExporter_HalfOpen_SuccessClosesCircuit() + { + // Arrange + var config = new CircuitBreakerConfig + { + FailureThreshold = 1, + Timeout = TimeSpan.FromMilliseconds(50), + SuccessThreshold = 1 + }; + var manager = new CircuitBreakerManager(config); + var mockExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Cause failure to open the circuit + await exporter.ExportAsync(logs); + + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + + // Wait for timeout to allow transition to HalfOpen + await Task.Delay(100); + + // Now succeed + mockExporter.ShouldThrow = false; + await exporter.ExportAsync(logs); + + // Assert - Circuit should be closed after success in HalfOpen + Assert.Equal(CircuitBreakerState.Closed, circuitBreaker.State); + } + + [Fact] + public async Task CircuitBreakerTelemetryExporter_HalfOpen_FailureReopensCircuit() + { + // Arrange + var config = new CircuitBreakerConfig + { + FailureThreshold = 1, + Timeout = TimeSpan.FromMilliseconds(50), + SuccessThreshold = 2 + }; + var manager = new CircuitBreakerManager(config); + var mockExporter = new MockTelemetryExporter { ShouldThrow = true }; + var exporter = new CircuitBreakerTelemetryExporter(TestHost, mockExporter, manager); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Cause failure to open the circuit + await exporter.ExportAsync(logs); + + var circuitBreaker = manager.GetCircuitBreaker(TestHost); + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + + // Wait for timeout to allow transition to HalfOpen + await Task.Delay(100); + + // Fail again - this will transition from HalfOpen to Open + await exporter.ExportAsync(logs); + + // Assert - Circuit should be open again + Assert.Equal(CircuitBreakerState.Open, circuitBreaker.State); + } + + #endregion + + #region Mock Telemetry Exporter + + /// + /// Mock telemetry exporter for testing. + /// + private class MockTelemetryExporter : ITelemetryExporter + { + public int ExportCallCount { get; set; } + public IReadOnlyList? 
LastExportedLogs { get; private set; } + public bool ShouldThrow { get; set; } + public bool ShouldDelay { get; set; } + + public async Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + + if (ShouldDelay) + { + await Task.Delay(1000, ct); + } + + ExportCallCount++; + LastExportedLogs = logs; + + if (ShouldThrow) + { + throw new Exception("Simulated export failure"); + } + } + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/DatabricksActivityListenerTests.cs b/csharp/test/Unit/Telemetry/DatabricksActivityListenerTests.cs new file mode 100644 index 00000000..255cb437 --- /dev/null +++ b/csharp/test/Unit/Telemetry/DatabricksActivityListenerTests.cs @@ -0,0 +1,796 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for DatabricksActivityListener class. + /// + public class DatabricksActivityListenerTests : IDisposable + { + private readonly MockTelemetryExporter _mockExporter; + private readonly TelemetryConfiguration _config; + private readonly ActivitySource _databricksActivitySource; + private readonly ActivitySource _otherActivitySource; + private MetricsAggregator? _aggregator; + private DatabricksActivityListener? 
_listener; + + private const long TestWorkspaceId = 12345; + private const string TestUserAgent = "TestAgent/1.0"; + + public DatabricksActivityListenerTests() + { + _mockExporter = new MockTelemetryExporter(); + _config = new TelemetryConfiguration + { + Enabled = true, + BatchSize = 100, + FlushIntervalMs = 60000 // High value to control flush timing + }; + _databricksActivitySource = new ActivitySource(DatabricksActivityListener.DatabricksActivitySourceName); + _otherActivitySource = new ActivitySource("Other.ActivitySource"); + } + + public void Dispose() + { + _listener?.Dispose(); + _aggregator?.Dispose(); + _databricksActivitySource.Dispose(); + _otherActivitySource.Dispose(); + } + + #region Constructor Tests + + [Fact] + public void DatabricksActivityListener_Constructor_NullAggregator_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new DatabricksActivityListener(null!, _config)); + } + + [Fact] + public void DatabricksActivityListener_Constructor_NullConfig_ThrowsException() + { + // Arrange + _aggregator = CreateAggregator(); + + // Act & Assert + Assert.Throws(() => + new DatabricksActivityListener(_aggregator, null!)); + } + + [Fact] + public void DatabricksActivityListener_Constructor_ValidParameters_CreatesInstance() + { + // Arrange + _aggregator = CreateAggregator(); + + // Act + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Assert + Assert.NotNull(_listener); + Assert.False(_listener.IsStarted); + Assert.False(_listener.IsDisposed); + } + + [Fact] + public void DatabricksActivityListener_Constructor_WithFeatureFlagChecker_CreatesInstance() + { + // Arrange + _aggregator = CreateAggregator(); + Func featureFlagChecker = () => true; + + // Act + _listener = new DatabricksActivityListener(_aggregator, _config, featureFlagChecker); + + // Assert + Assert.NotNull(_listener); + } + + #endregion + + #region Start Tests + + [Fact] + public void DatabricksActivityListener_Start_SetsIsStartedToTrue() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Act + _listener.Start(); + + // Assert + Assert.True(_listener.IsStarted); + } + + [Fact] + public void DatabricksActivityListener_Start_ListensToDatabricksActivitySource() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act - create and stop an activity from Databricks source + using var activity = _databricksActivitySource.StartActivity("Connection.Open"); + + // Assert - activity should be created (listener is listening) + Assert.NotNull(activity); + } + + [Fact] + public void DatabricksActivityListener_Start_IgnoresOtherActivitySources() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act - create an activity from a different source + // When a listener filters out a source, no activity is created + using var otherActivity = _otherActivitySource.StartActivity("SomeOperation"); + + // Assert - activity may or may not be null depending on other listeners + // The key test is that our listener's callbacks are not triggered + // This is verified in the ActivityStopped tests + } + + [Fact] + public void DatabricksActivityListener_Start_MultipleCallsAreIdempotent() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Act - start multiple times + 
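+ // Start() is expected to be idempotent: repeated calls should neither throw nor register duplicate listeners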
_listener.Start(); + _listener.Start(); + _listener.Start(); + + // Assert - should still be started, no exceptions + Assert.True(_listener.IsStarted); + } + + [Fact] + public void DatabricksActivityListener_Start_AfterDispose_DoesNothing() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Dispose(); + + // Act - start after dispose + _listener.Start(); + + // Assert - should not be started + Assert.False(_listener.IsStarted); + Assert.True(_listener.IsDisposed); + } + + #endregion + + #region ShouldListenTo Tests + + [Fact] + public void DatabricksActivityListener_ShouldListenTo_DatabricksSource_ReturnsTrue() + { + // This is implicitly tested by Start_ListensToDatabricksActivitySource + // The activity would not be created if ShouldListenTo returned false + + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert + Assert.NotNull(activity); + } + + #endregion + + #region Sample Tests - Feature Flag + + [Fact] + public void DatabricksActivityListener_Sample_FeatureFlagEnabled_ReturnsAllDataAndRecorded() + { + // Arrange + var enabledConfig = new TelemetryConfiguration { Enabled = true, FlushIntervalMs = 60000 }; + _aggregator = new MetricsAggregator(_mockExporter, enabledConfig, TestWorkspaceId, TestUserAgent); + _listener = new DatabricksActivityListener(_aggregator, enabledConfig); + _listener.Start(); + + // Act + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert - activity should be created and recorded + Assert.NotNull(activity); + Assert.True(activity.Recorded); + Assert.True(activity.IsAllDataRequested); + } + + [Fact] + public void DatabricksActivityListener_Sample_FeatureFlagDisabled_ReturnsNone() + { + // Arrange + var disabledConfig = new TelemetryConfiguration { Enabled = false, FlushIntervalMs = 60000 }; + _aggregator = new MetricsAggregator(_mockExporter, disabledConfig, TestWorkspaceId, TestUserAgent); + _listener = new DatabricksActivityListener(_aggregator, disabledConfig); + _listener.Start(); + + // Act + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert - when disabled, our Sample returns None, so our listener + // won't process the activity. Note: activity may still be created + // if other listeners are registered (e.g., from other tests). + // The key assertion is that our aggregator should not process + // activities when disabled. 
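+ // With telemetry disabled, nothing reaches the aggregator, so the pending count must stay at zero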
+ Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void DatabricksActivityListener_Sample_DynamicFeatureFlagChecker_Enabled_ReturnsAllData() + { + // Arrange + bool featureFlagValue = true; + Func featureFlagChecker = () => featureFlagValue; + + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config, featureFlagChecker); + _listener.Start(); + + // Act + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert + Assert.NotNull(activity); + Assert.True(activity.Recorded); + } + + [Fact] + public void DatabricksActivityListener_Sample_DynamicFeatureFlagChecker_Disabled_ReturnsNone() + { + // Arrange + bool featureFlagValue = false; + Func featureFlagChecker = () => featureFlagValue; + + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config, featureFlagChecker); + _listener.Start(); + + // Act + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert - when feature flag is disabled, our Sample returns None, + // so our listener won't process the activity. + // Note: activity may still be created if other listeners exist. + Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void DatabricksActivityListener_Sample_FeatureFlagCheckerThrows_ReturnsNone() + { + // Arrange + Func throwingChecker = () => throw new InvalidOperationException("Test exception"); + + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config, throwingChecker); + _listener.Start(); + + // Act - should not throw, should return None (activity not created) + using var activity = _databricksActivitySource.StartActivity("TestOperation"); + + // Assert - activity should not be created due to exception handling + Assert.Null(activity); + } + + #endregion + + #region ActivityStopped Tests + + [Fact] + public void DatabricksActivityListener_ActivityStopped_ProcessesActivity() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act + using (var activity = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session-123"); + activity.Stop(); + } + + // Assert - aggregator should have processed the activity + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void DatabricksActivityListener_ActivityStopped_ProcessesMultipleActivities() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act + for (int i = 0; i < 5; i++) + { + using var activity = _databricksActivitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + } + + // Assert - all activities should be processed + Assert.Equal(5, _aggregator.PendingEventCount); + } + + [Fact] + public void DatabricksActivityListener_ActivityStopped_ExceptionSwallowed() + { + // Arrange - use a throwing exporter to test exception handling + // The MetricsAggregator will handle the export exception, and + // the listener will swallow any exceptions from ProcessActivity + var throwingExporter = new ThrowingTelemetryExporter(); + var config = new TelemetryConfiguration { BatchSize = 1, FlushIntervalMs = 60000 }; + var aggregator = new MetricsAggregator(throwingExporter, config, 
TestWorkspaceId, TestUserAgent); + _listener = new DatabricksActivityListener(aggregator, config); + _listener.Start(); + + // Act & Assert - should not throw even though exporter throws + using (var activity = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session"); + activity.Stop(); + // No exception should be thrown + } + + aggregator.Dispose(); + } + + [Fact] + public void DatabricksActivityListener_ActivityStopped_NotCalledWhenDisabled() + { + // Arrange + var disabledConfig = new TelemetryConfiguration { Enabled = false, FlushIntervalMs = 60000 }; + _aggregator = new MetricsAggregator(_mockExporter, disabledConfig, TestWorkspaceId, TestUserAgent); + _listener = new DatabricksActivityListener(_aggregator, disabledConfig); + _listener.Start(); + + // Act - our listener's Sample returns None when disabled + using var activity = _databricksActivitySource.StartActivity("Connection.Open"); + + // Note: activity may be created if other listeners are active (test isolation issue) + // but if created, stop it to simulate the full lifecycle + if (activity != null) + { + activity.SetTag("session.id", "test-session"); + activity.Stop(); + } + + // Assert - when disabled, our listener's Sample returns None, + // so our listener won't receive the ActivityStopped callback and + // won't process the activity + Assert.Equal(0, _aggregator.PendingEventCount); + } + + #endregion + + #region StopAsync Tests + + [Fact] + public async Task DatabricksActivityListener_StopAsync_FlushesAndDisposes() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Create some activities + for (int i = 0; i < 3; i++) + { + using var activity = _databricksActivitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + } + + Assert.Equal(3, _aggregator.PendingEventCount); + + // Act + await _listener.StopAsync(); + + // Assert + Assert.False(_listener.IsStarted); + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public async Task DatabricksActivityListener_StopAsync_StopsListening() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Create an activity before stopping + using (var activity1 = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity1); + activity1.Stop(); + } + + Assert.Equal(1, _aggregator.PendingEventCount); + + // Stop the listener + await _listener.StopAsync(); + Assert.False(_listener.IsStarted); + + // Clear the exporter count + var countBeforeNewActivity = _mockExporter.ExportCallCount; + + // Try to create another activity after stopping + using var activity2 = _databricksActivitySource.StartActivity("Connection.Open"); + + // Activity may still be created by other listeners, but our listener + // should not process it since it's stopped + } + + [Fact] + public async Task DatabricksActivityListener_StopAsync_CanBeCalledMultipleTimes() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act & Assert - should not throw + await _listener.StopAsync(); + await _listener.StopAsync(); + await _listener.StopAsync(); + + Assert.False(_listener.IsStarted); + } + + [Fact] + public async Task 
DatabricksActivityListener_StopAsync_BeforeStart_DoesNothing() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Act & Assert - should not throw + await _listener.StopAsync(); + + Assert.False(_listener.IsStarted); + } + + [Fact] + public async Task DatabricksActivityListener_StopAsync_WithCancellation_PropagatesCancellation() + { + // Arrange - use an exporter that throws on cancellation + var cancellingExporter = new CancellingTelemetryExporter(); + var aggregator = new MetricsAggregator(cancellingExporter, _config, TestWorkspaceId, TestUserAgent); + _listener = new DatabricksActivityListener(aggregator, _config); + _listener.Start(); + + // Create some activities so there's something to flush + using (var activity = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session"); + activity.Stop(); + } + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - should propagate cancellation + await Assert.ThrowsAsync(() => + _listener.StopAsync(cts.Token)); + + aggregator.Dispose(); + } + + #endregion + + #region Dispose Tests + + [Fact] + public void DatabricksActivityListener_Dispose_SetsIsDisposedToTrue() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Act + _listener.Dispose(); + + // Assert + Assert.True(_listener.IsDisposed); + } + + [Fact] + public void DatabricksActivityListener_Dispose_FlushesRemainingEvents() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Create some activities + for (int i = 0; i < 3; i++) + { + using var activity = _databricksActivitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + } + + // Act + _listener.Dispose(); + + // Assert - events should have been flushed via aggregator dispose + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public void DatabricksActivityListener_Dispose_CanBeCalledMultipleTimes() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + + // Act & Assert - should not throw + _listener.Dispose(); + _listener.Dispose(); + _listener.Dispose(); + + Assert.True(_listener.IsDisposed); + } + + [Fact] + public void DatabricksActivityListener_Dispose_StopsListening() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + Assert.True(_listener.IsStarted); + + // Act + _listener.Dispose(); + + // Assert + Assert.False(_listener.IsStarted); + Assert.True(_listener.IsDisposed); + } + + #endregion + + #region Integration Tests + + [Fact] + public async Task DatabricksActivityListener_EndToEnd_ConnectionActivity() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + // Act - simulate connection open + using (var activity = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity); + activity.SetTag("session.id", "e2e-session-123"); + activity.SetTag("driver.version", "1.0.0"); + activity.SetTag("driver.os", "Windows"); + activity.Stop(); + } + + // Flush + await _aggregator.FlushAsync(); + + // Assert + 
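+ // The explicit flush should have delivered at least one batch, containing at least one event, to the mock exporter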
Assert.True(_mockExporter.ExportCallCount > 0); + Assert.True(_mockExporter.TotalExportedEvents > 0); + } + + [Fact] + public async Task DatabricksActivityListener_EndToEnd_StatementActivity() + { + // Arrange + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config); + _listener.Start(); + + var statementId = "e2e-stmt-123"; + var sessionId = "e2e-session-456"; + + // Act - simulate statement execution + using (var activity = _databricksActivitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", sessionId); + activity.SetTag("result.format", "cloudfetch"); + activity.SetTag("result.chunk_count", "5"); + activity.Stop(); + } + + // Complete the statement + _aggregator.CompleteStatement(statementId); + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.True(_mockExporter.TotalExportedEvents > 0); + } + + [Fact] + public async Task DatabricksActivityListener_EndToEnd_DynamicFeatureFlag() + { + // Arrange + bool featureFlagEnabled = true; + Func featureFlagChecker = () => featureFlagEnabled; + + _aggregator = CreateAggregator(); + _listener = new DatabricksActivityListener(_aggregator, _config, featureFlagChecker); + _listener.Start(); + + // Act 1 - create activity while enabled + using (var activity1 = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity1); + activity1.SetTag("session.id", "enabled-session"); + activity1.Stop(); + } + + var countAfterFirstActivity = _aggregator.PendingEventCount; + Assert.Equal(1, countAfterFirstActivity); + + // Disable feature flag + featureFlagEnabled = false; + + // Act 2 - try to create activity while disabled + // Note: activity may be created if other listeners exist + using (var activity2 = _databricksActivitySource.StartActivity("Connection.Open")) + { + // Our listener should not process this activity because feature flag is disabled + // If activity exists (due to other listeners), stop it + if (activity2 != null) + { + activity2.SetTag("session.id", "disabled-session"); + activity2.Stop(); + } + } + + // Still only 1 pending event (our listener didn't process the disabled one) + Assert.Equal(countAfterFirstActivity, _aggregator.PendingEventCount); + + // Re-enable feature flag + featureFlagEnabled = true; + + // Act 3 - create activity while re-enabled + using (var activity3 = _databricksActivitySource.StartActivity("Connection.Open")) + { + Assert.NotNull(activity3); + activity3.SetTag("session.id", "reenabled-session"); + activity3.Stop(); + } + + // Now 2 pending events + Assert.Equal(2, _aggregator.PendingEventCount); + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(2, _mockExporter.TotalExportedEvents); + } + + #endregion + + #region Helper Methods + + private MetricsAggregator CreateAggregator() + { + return new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + } + + #endregion + + #region Mock Classes + + /// + /// Mock telemetry exporter for testing. 
+ /// + private class MockTelemetryExporter : ITelemetryExporter + { + private int _exportCallCount; + private int _totalExportedEvents; + private readonly ConcurrentBag _exportedLogs = new ConcurrentBag(); + + public int ExportCallCount => _exportCallCount; + public int TotalExportedEvents => _totalExportedEvents; + public IReadOnlyCollection ExportedLogs => _exportedLogs.ToList(); + + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + + Interlocked.Increment(ref _exportCallCount); + Interlocked.Add(ref _totalExportedEvents, logs.Count); + + foreach (var log in logs) + { + _exportedLogs.Add(log); + } + + return Task.CompletedTask; + } + } + + /// + /// Telemetry exporter that throws for testing exception handling. + /// + private class ThrowingTelemetryExporter : ITelemetryExporter + { + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + throw new InvalidOperationException("Test exception from exporter"); + } + } + + /// + /// Telemetry exporter that throws OperationCanceledException for testing cancellation. + /// + private class CancellingTelemetryExporter : ITelemetryExporter + { + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + ct.ThrowIfCancellationRequested(); + return Task.CompletedTask; + } + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs b/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs new file mode 100644 index 00000000..fdb657db --- /dev/null +++ b/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs @@ -0,0 +1,603 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for DatabricksTelemetryExporter class. 
+ /// + public class DatabricksTelemetryExporterTests + { + private const string TestHost = "https://test-workspace.databricks.com"; + + #region Constructor Tests + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullHttpClient_ThrowsException() + { + // Arrange + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(null!, TestHost, true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, null!, true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_EmptyHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, "", true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_WhitespaceHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, " ", true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullConfig_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, TestHost, true, null!)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_ValidParameters_SetsProperties() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Assert + Assert.Equal(TestHost, exporter.Host); + Assert.True(exporter.IsAuthenticated); + } + + #endregion + + #region Endpoint Tests + + [Fact] + public void DatabricksTelemetryExporter_ExportAsync_Authenticated_UsesCorrectEndpoint() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-ext", endpointUrl); + } + + [Fact] + public void DatabricksTelemetryExporter_ExportAsync_Unauthenticated_UsesCorrectEndpoint() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, false, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-unauth", endpointUrl); + } + + [Fact] + public void DatabricksTelemetryExporter_GetEndpointUrl_HostWithTrailingSlash_HandlesCorrectly() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, $"{TestHost}/", true, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-ext", endpointUrl); + } + + #endregion + + #region TelemetryRequest Creation Tests + + [Fact] + public void DatabricksTelemetryExporter_CreateTelemetryRequest_SingleLog_CreatesCorrectFormat() + 
{ + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog + { + WorkspaceId = 12345, + FrontendLogEventId = "test-event-id" + } + }; + + // Act + var request = exporter.CreateTelemetryRequest(logs); + + // Assert + Assert.True(request.UploadTime > 0); + Assert.Single(request.ProtoLogs); + Assert.Contains("12345", request.ProtoLogs[0]); + Assert.Contains("test-event-id", request.ProtoLogs[0]); + } + + [Fact] + public void DatabricksTelemetryExporter_CreateTelemetryRequest_MultipleLogs_CreatesCorrectFormat() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 1, FrontendLogEventId = "event-1" }, + new TelemetryFrontendLog { WorkspaceId = 2, FrontendLogEventId = "event-2" }, + new TelemetryFrontendLog { WorkspaceId = 3, FrontendLogEventId = "event-3" } + }; + + // Act + var request = exporter.CreateTelemetryRequest(logs); + + // Assert + Assert.Equal(3, request.ProtoLogs.Count); + } + + [Fact] + public void DatabricksTelemetryExporter_CreateTelemetryRequest_UploadTime_IsRecentTimestamp() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 1 } + }; + + var beforeTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + + // Act + var request = exporter.CreateTelemetryRequest(logs); + + var afterTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + + // Assert + Assert.True(request.UploadTime >= beforeTime); + Assert.True(request.UploadTime <= afterTime); + } + + #endregion + + #region Serialization Tests + + [Fact] + public void DatabricksTelemetryExporter_SerializeRequest_ProducesValidJson() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var request = new TelemetryRequest + { + UploadTime = 1234567890000, + ProtoLogs = new List { "{\"workspace_id\":12345}" } + }; + + // Act + var json = exporter.SerializeRequest(request); + + // Assert + Assert.NotEmpty(json); + var parsed = JsonDocument.Parse(json); + Assert.Equal(1234567890000, parsed.RootElement.GetProperty("uploadTime").GetInt64()); + Assert.Single(parsed.RootElement.GetProperty("protoLogs").EnumerateArray()); + } + + #endregion + + #region ExportAsync Tests with Mock Handler + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_Success_ReturnsTrue() + { + // Arrange + var handler = new MockHttpMessageHandler((request, ct) => + { + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 0 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert + Assert.True(result, "ExportAsync should return true on successful HTTP 200 response"); + } + + [Fact] + public async Task 
DatabricksTelemetryExporter_ExportAsync_EmptyList_ReturnsTrueWithoutRequest() + { + // Arrange + var requestCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + requestCount++; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Act + var result = await exporter.ExportAsync(new List()); + + // Assert - no HTTP request should be made and should return true + Assert.Equal(0, requestCount); + Assert.True(result, "ExportAsync should return true for empty list (nothing to export)"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_NullList_ReturnsTrueWithoutRequest() + { + // Arrange + var requestCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + requestCount++; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Act + var result = await exporter.ExportAsync(null!); + + // Assert - no HTTP request should be made and should return true + Assert.Equal(0, requestCount); + Assert.True(result, "ExportAsync should return true for null list (nothing to export)"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_TransientFailure_RetriesAndReturnsTrue() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + if (attemptCount < 3) + { + throw new HttpRequestException("Simulated transient failure"); + } + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should retry and eventually succeed + Assert.Equal(3, attemptCount); + Assert.True(result, "ExportAsync should return true after successful retry"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_MaxRetries_ReturnsFalse() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + throw new HttpRequestException("Simulated persistent failure"); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should have tried initial attempt + max retries and return false + Assert.Equal(4, attemptCount); // 1 initial + 3 retries + Assert.False(result, "ExportAsync should return false after all retries exhausted"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_TerminalError_ReturnsFalseWithoutRetry() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + // Throw 
HttpRequestException with an inner UnauthorizedAccessException + // The ExceptionClassifier checks inner exceptions for terminal types + throw new HttpRequestException("Authentication failed", + new UnauthorizedAccessException("Unauthorized")); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should not retry on terminal error and return false + Assert.Equal(1, attemptCount); + Assert.False(result, "ExportAsync should return false on terminal error"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_Cancelled_ThrowsCancelledException() + { + // Arrange + var handler = new MockHttpMessageHandler((request, ct) => + { + ct.ThrowIfCancellationRequested(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert - TaskCanceledException inherits from OperationCanceledException + var ex = await Assert.ThrowsAnyAsync( + () => exporter.ExportAsync(logs, cts.Token)); + Assert.True(ex is OperationCanceledException); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_Authenticated_SendsToCorrectEndpoint() + { + // Arrange + string? capturedUrl = null; + var handler = new MockHttpMessageHandler((request, ct) => + { + capturedUrl = request.RequestUri?.ToString(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + await exporter.ExportAsync(logs); + + // Assert + Assert.Equal($"{TestHost}/telemetry-ext", capturedUrl); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_Unauthenticated_SendsToCorrectEndpoint() + { + // Arrange + string? capturedUrl = null; + var handler = new MockHttpMessageHandler((request, ct) => + { + capturedUrl = request.RequestUri?.ToString(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, false, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + await exporter.ExportAsync(logs); + + // Assert + Assert.Equal($"{TestHost}/telemetry-unauth", capturedUrl); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_SendsValidJsonBody() + { + // Arrange + string? 
capturedContent = null;
+            var handler = new MockHttpMessageHandler(async (request, ct) =>
+            {
+                capturedContent = await request.Content!.ReadAsStringAsync();
+                return new HttpResponseMessage(HttpStatusCode.OK);
+            });
+
+            using var httpClient = new HttpClient(handler);
+            var config = new TelemetryConfiguration();
+            var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+            var logs = new List<TelemetryFrontendLog>
+            {
+                new TelemetryFrontendLog
+                {
+                    WorkspaceId = 12345,
+                    FrontendLogEventId = "test-event-123"
+                }
+            };
+
+            // Act
+            await exporter.ExportAsync(logs);
+
+            // Assert
+            Assert.NotNull(capturedContent);
+            var parsedRequest = JsonDocument.Parse(capturedContent);
+            Assert.True(parsedRequest.RootElement.TryGetProperty("uploadTime", out _));
+            Assert.True(parsedRequest.RootElement.TryGetProperty("protoLogs", out var protoLogs));
+            Assert.Single(protoLogs.EnumerateArray());
+
+            var protoLogJson = protoLogs[0].GetString();
+            Assert.Contains("12345", protoLogJson);
+            Assert.Contains("test-event-123", protoLogJson);
+        }
+
+        [Fact]
+        public async Task DatabricksTelemetryExporter_ExportAsync_GenericException_ReturnsFalse()
+        {
+            // Arrange
+            var handler = new MockHttpMessageHandler((request, ct) =>
+            {
+                throw new InvalidOperationException("Unexpected error");
+            });
+
+            using var httpClient = new HttpClient(handler);
+            var config = new TelemetryConfiguration { MaxRetries = 0 };
+            var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+            var logs = new List<TelemetryFrontendLog>
+            {
+                new TelemetryFrontendLog { WorkspaceId = 12345 }
+            };
+
+            // Act
+            var result = await exporter.ExportAsync(logs);
+
+            // Assert - should not throw but return false
+            Assert.False(result, "ExportAsync should return false on unexpected exception");
+        }
+
+        #endregion
+
+        #region Mock HTTP Handler
+
+        /// <summary>
+        /// Mock HttpMessageHandler for testing HTTP requests.
+        /// </summary>
+        private class MockHttpMessageHandler : HttpMessageHandler
+        {
+            private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> _handler;
+
+            public MockHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
+            {
+                _handler = handler;
+            }
+
+            protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+            {
+                return _handler(request, cancellationToken);
+            }
+        }
+
+        #endregion
+    }
+}
diff --git a/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs b/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs
new file mode 100644
index 00000000..141fe84d
--- /dev/null
+++ b/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs
@@ -0,0 +1,812 @@
+/*
+* Copyright (c) 2025 ADBC Drivers Contributors
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using AdbcDrivers.Databricks.Telemetry;
+using AdbcDrivers.Databricks.Telemetry.Models;
+using Xunit;
+
+namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry
+{
+    ///
+    /// Tests for MetricsAggregator class.
+ /// + public class MetricsAggregatorTests : IDisposable + { + private readonly ActivitySource _activitySource; + private readonly MockTelemetryExporter _mockExporter; + private readonly TelemetryConfiguration _config; + private MetricsAggregator? _aggregator; + + private const long TestWorkspaceId = 12345; + private const string TestUserAgent = "TestAgent/1.0"; + + public MetricsAggregatorTests() + { + _activitySource = new ActivitySource("TestSource"); + _mockExporter = new MockTelemetryExporter(); + _config = new TelemetryConfiguration + { + BatchSize = 10, + FlushIntervalMs = 60000 // Set high to control when flush happens + }; + } + + public void Dispose() + { + _aggregator?.Dispose(); + _activitySource.Dispose(); + } + + #region Constructor Tests + + [Fact] + public void MetricsAggregator_Constructor_NullExporter_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new MetricsAggregator(null!, _config, TestWorkspaceId)); + } + + [Fact] + public void MetricsAggregator_Constructor_NullConfig_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new MetricsAggregator(_mockExporter, null!, TestWorkspaceId)); + } + + [Fact] + public void MetricsAggregator_Constructor_ValidParameters_CreatesInstance() + { + // Act + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Assert + Assert.NotNull(_aggregator); + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + #endregion + + #region ProcessActivity Tests - Connection Events + + [Fact] + public void MetricsAggregator_ProcessActivity_ConnectionOpen_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session-123"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_ProcessActivity_ConnectionOpenAsync_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("OpenAsync"); + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session-456"); + activity.SetTag("driver.version", "1.0.0"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert + Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region ProcessActivity Tests - Statement Events + + [Fact] + public void MetricsAggregator_ProcessActivity_Statement_AggregatesByStatementId() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); 
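+ // The two activities below share the same statement.id tag; the aggregator is expected
+ // to buffer them into a single aggregated event until CompleteStatement is called.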
+ + var statementId = "stmt-123"; + var sessionId = "session-456"; + + // First activity + using (var activity1 = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity1); + activity1.SetTag("statement.id", statementId); + activity1.SetTag("session.id", sessionId); + activity1.SetTag("result.chunk_count", "5"); + activity1.Stop(); + + _aggregator.ProcessActivity(activity1); + } + + // Second activity with same statement_id + using (var activity2 = _activitySource.StartActivity("Statement.FetchResults")) + { + Assert.NotNull(activity2); + activity2.SetTag("statement.id", statementId); + activity2.SetTag("session.id", sessionId); + activity2.SetTag("result.chunk_count", "3"); + activity2.Stop(); + + _aggregator.ProcessActivity(activity2); + } + + // Assert - should not emit until CompleteStatement + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_ProcessActivity_Statement_WithoutStatementId_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Statement.Execute"); + Assert.NotNull(activity); + // No statement.id tag + activity.SetTag("session.id", "session-123"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert - should emit immediately since no statement_id + Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region CompleteStatement Tests + + [Fact] + public void MetricsAggregator_CompleteStatement_EmitsAggregatedEvent() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-complete-123"; + var sessionId = "session-complete-456"; + + using (var activity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", sessionId); + activity.SetTag("result.format", "cloudfetch"); + activity.SetTag("result.chunk_count", "10"); + activity.Stop(); + + _aggregator.ProcessActivity(activity); + } + + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + + // Act + _aggregator.CompleteStatement(statementId); + + // Assert + Assert.Equal(1, _aggregator.PendingEventCount); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_CompleteStatement_NullStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.CompleteStatement(null!); + _aggregator.CompleteStatement(string.Empty); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_CompleteStatement_UnknownStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not 
throw + _aggregator.CompleteStatement("unknown-statement-id"); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + #endregion + + #region FlushAsync Tests + + [Fact] + public async Task MetricsAggregator_FlushAsync_BatchSizeReached_ExportsEvents() + { + // Arrange + var config = new TelemetryConfiguration { BatchSize = 2, FlushIntervalMs = 60000 }; + _aggregator = new MetricsAggregator(_mockExporter, config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + // Add events that should trigger flush + for (int i = 0; i < 3; i++) + { + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + } + + // Wait for any background flush to complete + await Task.Delay(100); + + // Assert - exporter should have been called + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_TimeInterval_ExportsEvents() + { + // Arrange + var config = new TelemetryConfiguration { BatchSize = 100, FlushIntervalMs = 100 }; // 100ms interval + _aggregator = new MetricsAggregator(_mockExporter, config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-timer"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + // Act - wait for timer to trigger flush + await Task.Delay(250); + + // Assert - exporter should have been called by timer + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_EmptyQueue_NoExport() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(0, _mockExporter.ExportCallCount); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_MultipleEvents_ExportsAll() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + for (int i = 0; i < 5; i++) + { + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + } + + Assert.Equal(5, _aggregator.PendingEventCount); + + // Act + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _mockExporter.ExportCallCount); + Assert.Equal(5, _mockExporter.TotalExportedEvents); + } + + #endregion + + #region RecordException Tests + + [Fact] + public void MetricsAggregator_RecordException_Terminal_FlushesImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, 
TestWorkspaceId, TestUserAgent); + + var terminalException = new HttpRequestException("401 (Unauthorized)"); + + // Act + _aggregator.RecordException("stmt-123", "session-456", terminalException); + + // Assert - terminal exception should be queued immediately + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_BuffersUntilComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-retryable-123"; + var sessionId = "session-retryable-456"; + + // Act + _aggregator.RecordException(statementId, sessionId, retryableException); + + // Assert - retryable exception should be buffered, not queued + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_EmittedOnFailedComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-failed-123"; + var sessionId = "session-failed-456"; + + _aggregator.RecordException(statementId, sessionId, retryableException); + Assert.Equal(0, _aggregator.PendingEventCount); + + // Act - complete statement as failed + _aggregator.CompleteStatement(statementId, failed: true); + + // Assert - both statement event and error event should be emitted + Assert.Equal(2, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_NotEmittedOnSuccessComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-success-123"; + var sessionId = "session-success-456"; + + _aggregator.RecordException(statementId, sessionId, retryableException); + + // Act - complete statement as success + _aggregator.CompleteStatement(statementId, failed: false); + + // Assert - only statement event should be emitted, not the error + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_NullException_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.RecordException("stmt-123", "session-456", null); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_NullStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.RecordException(null!, "session-456", new Exception("test")); + _aggregator.RecordException(string.Empty, "session-456", new Exception("test")); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + #endregion + + #region Exception Swallowing Tests + + [Fact] + public void MetricsAggregator_ProcessActivity_ExceptionSwallowed_NoThrow() + { + // Arrange + var throwingExporter = new ThrowingTelemetryExporter(); + _aggregator = new MetricsAggregator(throwingExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act & Assert - should not throw even with throwing exporter + 
_aggregator.ProcessActivity(null); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_ExceptionSwallowed_NoThrow() + { + // Arrange + var throwingExporter = new ThrowingTelemetryExporter(); + _aggregator = new MetricsAggregator(throwingExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-throw"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + // Act & Assert - should not throw + await _aggregator.FlushAsync(); + } + + #endregion + + #region Tag Filtering Tests + + [Fact] + public void MetricsAggregator_ProcessActivity_FiltersTags_UsingRegistry() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-filter-123"; + + using (var activity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", "session-filter"); + activity.SetTag("result.format", "cloudfetch"); + // This sensitive tag should be filtered out + activity.SetTag("db.statement", "SELECT * FROM sensitive_table"); + // This tag should be included + activity.SetTag("result.chunk_count", "5"); + activity.Stop(); + + _aggregator.ProcessActivity(activity); + } + + // Complete to emit + _aggregator.CompleteStatement(statementId); + + // Assert - event should be created (we can verify via export) + Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region WrapInFrontendLog Tests + + [Fact] + public void MetricsAggregator_WrapInFrontendLog_CreatesValidStructure() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var telemetryEvent = new TelemetryEvent + { + SessionId = "session-wrap-123", + SqlStatementId = "stmt-wrap-456", + OperationLatencyMs = 100 + }; + + // Act + var frontendLog = _aggregator.WrapInFrontendLog(telemetryEvent); + + // Assert + Assert.NotNull(frontendLog); + Assert.Equal(TestWorkspaceId, frontendLog.WorkspaceId); + Assert.NotEmpty(frontendLog.FrontendLogEventId); + Assert.NotNull(frontendLog.Context); + Assert.NotNull(frontendLog.Context.ClientContext); + Assert.Equal(TestUserAgent, frontendLog.Context.ClientContext.UserAgent); + Assert.True(frontendLog.Context.TimestampMillis > 0); + Assert.NotNull(frontendLog.Entry); + Assert.NotNull(frontendLog.Entry.SqlDriverLog); + Assert.Equal(telemetryEvent.SessionId, frontendLog.Entry.SqlDriverLog.SessionId); + Assert.Equal(telemetryEvent.SqlStatementId, frontendLog.Entry.SqlDriverLog.SqlStatementId); + } + + [Fact] + public void MetricsAggregator_WrapInFrontendLog_GeneratesUniqueEventIds() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var telemetryEvent = new TelemetryEvent { SessionId = "session-unique" }; + + // Act + var frontendLog1 = _aggregator.WrapInFrontendLog(telemetryEvent); + var frontendLog2 = 
_aggregator.WrapInFrontendLog(telemetryEvent); + + // Assert + Assert.NotEqual(frontendLog1.FrontendLogEventId, frontendLog2.FrontendLogEventId); + } + + #endregion + + #region Dispose Tests + + [Fact] + public void MetricsAggregator_Dispose_FlushesRemainingEvents() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-dispose"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + Assert.Equal(1, _aggregator.PendingEventCount); + + // Act + _aggregator.Dispose(); + + // Assert - events should have been flushed + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public void MetricsAggregator_Dispose_CanBeCalledMultipleTimes() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act & Assert - should not throw + _aggregator.Dispose(); + _aggregator.Dispose(); + } + + #endregion + + #region Integration Tests + + [Fact] + public async Task MetricsAggregator_EndToEnd_StatementLifecycle() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-e2e-123"; + var sessionId = "session-e2e-456"; + + // Simulate statement execution + using (var executeActivity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(executeActivity); + executeActivity.SetTag("statement.id", statementId); + executeActivity.SetTag("session.id", sessionId); + executeActivity.SetTag("result.format", "cloudfetch"); + executeActivity.Stop(); + _aggregator.ProcessActivity(executeActivity); + } + + // Simulate chunk downloads + for (int i = 0; i < 3; i++) + { + using var downloadActivity = _activitySource.StartActivity("CloudFetch.Download"); + Assert.NotNull(downloadActivity); + downloadActivity.SetTag("statement.id", statementId); + downloadActivity.SetTag("session.id", sessionId); + downloadActivity.SetTag("result.chunk_count", "1"); + downloadActivity.SetTag("result.bytes_downloaded", "1000"); + downloadActivity.Stop(); + _aggregator.ProcessActivity(downloadActivity); + } + + // Complete statement + _aggregator.CompleteStatement(statementId); + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(1, _mockExporter.ExportCallCount); + Assert.Equal(1, _mockExporter.TotalExportedEvents); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + [Fact] + public async Task MetricsAggregator_EndToEnd_MultipleStatements() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var sessionId = "session-multi"; + + // Execute 3 statements + for (int i = 0; i < 3; i++) + { + var statementId = $"stmt-multi-{i}"; 
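+ // Each iteration processes and completes its own statement, so three aggregated
+ // events are expected when the flush below runs.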
+ + using var activity = _activitySource.StartActivity("Statement.Execute"); + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", sessionId); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + _aggregator.CompleteStatement(statementId); + } + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(3, _mockExporter.TotalExportedEvents); + } + + #endregion + + #region Mock Classes + + /// + /// Mock telemetry exporter for testing. + /// + private class MockTelemetryExporter : ITelemetryExporter + { + private int _exportCallCount; + private int _totalExportedEvents; + private readonly ConcurrentBag _exportedLogs = new ConcurrentBag(); + + public int ExportCallCount => _exportCallCount; + public int TotalExportedEvents => _totalExportedEvents; + public IReadOnlyCollection ExportedLogs => _exportedLogs.ToList(); + + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + Interlocked.Increment(ref _exportCallCount); + Interlocked.Add(ref _totalExportedEvents, logs.Count); + + foreach (var log in logs) + { + _exportedLogs.Add(log); + } + + return Task.CompletedTask; + } + } + + /// + /// Telemetry exporter that always throws for testing exception handling. + /// + private class ThrowingTelemetryExporter : ITelemetryExporter + { + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + throw new InvalidOperationException("Test exception from exporter"); + } + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/TelemetryClientManagerTests.cs b/csharp/test/Unit/Telemetry/TelemetryClientManagerTests.cs new file mode 100644 index 00000000..fc14270b --- /dev/null +++ b/csharp/test/Unit/Telemetry/TelemetryClientManagerTests.cs @@ -0,0 +1,720 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for TelemetryClientManager, TelemetryClientHolder, TelemetryClientAdapter, and ITelemetryClient. 
+ /// + public class TelemetryClientManagerTests + { + #region TelemetryClientHolder Tests + + [Fact] + public void TelemetryClientHolder_Constructor_InitializesCorrectly() + { + // Arrange + var mockClient = new MockTelemetryClient("test-host"); + + // Act + var holder = new TelemetryClientHolder(mockClient); + + // Assert + Assert.Same(mockClient, holder.Client); + Assert.Equal(0, holder.RefCount); + } + + [Fact] + public void TelemetryClientHolder_Constructor_NullClient_ThrowsException() + { + // Act & Assert + Assert.Throws(() => new TelemetryClientHolder(null!)); + } + + [Fact] + public void TelemetryClientHolder_IncrementRefCount_IncrementsCorrectly() + { + // Arrange + var mockClient = new MockTelemetryClient("test-host"); + var holder = new TelemetryClientHolder(mockClient); + + // Act & Assert + Assert.Equal(0, holder.RefCount); + Assert.Equal(1, holder.IncrementRefCount()); + Assert.Equal(1, holder.RefCount); + Assert.Equal(2, holder.IncrementRefCount()); + Assert.Equal(2, holder.RefCount); + } + + [Fact] + public void TelemetryClientHolder_DecrementRefCount_DecrementsCorrectly() + { + // Arrange + var mockClient = new MockTelemetryClient("test-host"); + var holder = new TelemetryClientHolder(mockClient); + holder.IncrementRefCount(); + holder.IncrementRefCount(); + + // Act & Assert + Assert.Equal(2, holder.RefCount); + Assert.Equal(1, holder.DecrementRefCount()); + Assert.Equal(1, holder.RefCount); + Assert.Equal(0, holder.DecrementRefCount()); + Assert.Equal(0, holder.RefCount); + } + + #endregion + + #region TelemetryClientAdapter Tests + + [Fact] + public void TelemetryClientAdapter_Constructor_InitializesCorrectly() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act + var adapter = new TelemetryClientAdapter("test-host", mockExporter); + + // Assert + Assert.Equal("test-host", adapter.Host); + Assert.Same(mockExporter, adapter.Exporter); + } + + [Fact] + public void TelemetryClientAdapter_Constructor_NullHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => new TelemetryClientAdapter(null!, mockExporter)); + } + + [Fact] + public void TelemetryClientAdapter_Constructor_EmptyHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => new TelemetryClientAdapter("", mockExporter)); + } + + [Fact] + public void TelemetryClientAdapter_Constructor_WhitespaceHost_ThrowsException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + + // Act & Assert + Assert.Throws(() => new TelemetryClientAdapter(" ", mockExporter)); + } + + [Fact] + public void TelemetryClientAdapter_Constructor_NullExporter_ThrowsException() + { + // Act & Assert + Assert.Throws(() => new TelemetryClientAdapter("test-host", null!)); + } + + [Fact] + public async Task TelemetryClientAdapter_ExportAsync_DelegatesToExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var adapter = new TelemetryClientAdapter("test-host", mockExporter); + var logs = new List { CreateTestLog() }; + + // Act + await adapter.ExportAsync(logs); + + // Assert + Assert.Equal(1, mockExporter.ExportCallCount); + Assert.Same(logs, mockExporter.LastExportedLogs); + } + + [Fact] + public async Task TelemetryClientAdapter_CloseAsync_CompletesWithoutException() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var adapter = new TelemetryClientAdapter("test-host", mockExporter); + + // Act & Assert - should not 
throw + await adapter.CloseAsync(); + } + + #endregion + + #region TelemetryClientManager Singleton Tests + + [Fact] + public void TelemetryClientManager_GetInstance_ReturnsSingleton() + { + // Act + var instance1 = TelemetryClientManager.GetInstance(); + var instance2 = TelemetryClientManager.GetInstance(); + + // Assert + Assert.Same(instance1, instance2); + } + + #endregion + + #region TelemetryClientManager_GetOrCreateClient Tests + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_NewHost_CreatesClient() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "test-host-1.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var client = manager.GetOrCreateClient(host, httpClient, config); + + // Assert + Assert.NotNull(client); + Assert.Equal(host, client.Host); + Assert.True(manager.HasClient(host)); + Assert.Equal(1, manager.ClientCount); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_ExistingHost_ReturnsSameClient() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "test-host-2.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var client1 = manager.GetOrCreateClient(host, httpClient, config); + var client2 = manager.GetOrCreateClient(host, httpClient, config); + + // Assert + Assert.Same(client1, client2); + Assert.Equal(1, manager.ClientCount); + + // Verify reference count incremented + manager.TryGetHolder(host, out var holder); + Assert.Equal(2, holder!.RefCount); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_MultipleHosts_CreatesMultipleClients() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host1 = "host1.databricks.com"; + var host2 = "host2.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var client1 = manager.GetOrCreateClient(host1, httpClient, config); + var client2 = manager.GetOrCreateClient(host2, httpClient, config); + + // Assert + Assert.NotSame(client1, client2); + Assert.Equal(host1, client1.Host); + Assert.Equal(host2, client2.Host); + Assert.Equal(2, manager.ClientCount); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_NullHost_ThrowsException() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => manager.GetOrCreateClient(null!, httpClient, config)); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_EmptyHost_ThrowsException() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => manager.GetOrCreateClient("", httpClient, config)); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_WhitespaceHost_ThrowsException() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => manager.GetOrCreateClient(" ", httpClient, config)); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_NullHttpClient_ThrowsException() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => 
manager.GetOrCreateClient("host", null!, config)); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_NullConfig_ThrowsException() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var httpClient = new HttpClient(); + + // Act & Assert + Assert.Throws(() => manager.GetOrCreateClient("host", httpClient, null!)); + } + + [Fact] + public void TelemetryClientManager_GetOrCreateClient_CaseInsensitive() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "Test-Host.Databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var client1 = manager.GetOrCreateClient(host.ToLower(), httpClient, config); + var client2 = manager.GetOrCreateClient(host.ToUpper(), httpClient, config); + + // Assert + Assert.Same(client1, client2); + Assert.Equal(1, manager.ClientCount); + } + + #endregion + + #region TelemetryClientManager_ReleaseClientAsync Tests + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_LastReference_ClosesClient() + { + // Arrange + var closedClients = new List(); + var manager = new TelemetryClientManager((host, httpClient, config) => + { + var mockClient = new MockTelemetryClient(host); + mockClient.OnClose = () => closedClients.Add(host); + return mockClient; + }); + var host = "test-host-3.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = manager.GetOrCreateClient(host, httpClient, config); + + // Act + await manager.ReleaseClientAsync(host); + + // Assert + Assert.False(manager.HasClient(host)); + Assert.Equal(0, manager.ClientCount); + Assert.Contains(host, closedClients); + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_MultipleReferences_KeepsClient() + { + // Arrange + var closedClients = new List(); + var manager = new TelemetryClientManager((host, httpClient, config) => + { + var mockClient = new MockTelemetryClient(host); + mockClient.OnClose = () => closedClients.Add(host); + return mockClient; + }); + var host = "test-host-4.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + manager.GetOrCreateClient(host, httpClient, config); + manager.GetOrCreateClient(host, httpClient, config); // Second reference + + // Act + await manager.ReleaseClientAsync(host); + + // Assert + Assert.True(manager.HasClient(host)); + manager.TryGetHolder(host, out var holder); + Assert.Equal(1, holder!.RefCount); + Assert.Empty(closedClients); // Client not closed + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_UnknownHost_DoesNothing() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act - should not throw + await manager.ReleaseClientAsync("unknown-host.databricks.com"); + + // Assert + Assert.Equal(0, manager.ClientCount); + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_NullHost_DoesNothing() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act - should not throw + await manager.ReleaseClientAsync(null!); + + // Assert - no exception thrown + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_EmptyHost_DoesNothing() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act - should not throw + await manager.ReleaseClientAsync(""); + + // Assert - no exception thrown + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_AllReleased_ClosesClient() + { + // Arrange 
+ var closedClients = new List(); + var manager = new TelemetryClientManager((host, httpClient, config) => + { + var mockClient = new MockTelemetryClient(host); + mockClient.OnClose = () => closedClients.Add(host); + return mockClient; + }); + var host = "test-host-5.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Create 3 references + manager.GetOrCreateClient(host, httpClient, config); + manager.GetOrCreateClient(host, httpClient, config); + manager.GetOrCreateClient(host, httpClient, config); + Assert.Equal(1, manager.ClientCount); + + // Act - Release all + await manager.ReleaseClientAsync(host); + Assert.True(manager.HasClient(host)); // Still has 2 references + + await manager.ReleaseClientAsync(host); + Assert.True(manager.HasClient(host)); // Still has 1 reference + + await manager.ReleaseClientAsync(host); + + // Assert + Assert.False(manager.HasClient(host)); + Assert.Equal(0, manager.ClientCount); + Assert.Single(closedClients); + } + + [Fact] + public async Task TelemetryClientManager_ReleaseClientAsync_CloseThrows_SwallowsException() + { + // Arrange + var manager = new TelemetryClientManager((host, httpClient, config) => + { + var mockClient = new MockTelemetryClient(host); + mockClient.OnClose = () => throw new InvalidOperationException("Close failed"); + return mockClient; + }); + var host = "test-host-6.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + manager.GetOrCreateClient(host, httpClient, config); + + // Act - should not throw + await manager.ReleaseClientAsync(host); + + // Assert - client was removed despite exception + Assert.False(manager.HasClient(host)); + } + + #endregion + + #region TelemetryClientManager Thread Safety Tests + + [Fact] + public async Task TelemetryClientManager_ConcurrentGetOrCreateClient_ThreadSafe_NoDuplicates() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "concurrent-host.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var tasks = new Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + tasks[i] = Task.Run(() => manager.GetOrCreateClient(host, httpClient, config)); + } + + var clients = await Task.WhenAll(tasks); + + // Assert - All should be the same client + var firstClient = clients[0]; + Assert.All(clients, c => Assert.Same(firstClient, c)); + Assert.Equal(1, manager.ClientCount); + + // Verify reference count + manager.TryGetHolder(host, out var holder); + Assert.Equal(100, holder!.RefCount); + } + + [Fact] + public async Task TelemetryClientManager_ConcurrentReleaseClient_ThreadSafe() + { + // Arrange + var closeCount = 0; + var manager = new TelemetryClientManager((host, httpClient, config) => + { + var mockClient = new MockTelemetryClient(host); + mockClient.OnClose = () => Interlocked.Increment(ref closeCount); + return mockClient; + }); + var host = "concurrent-release-host.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Create 100 references + for (int i = 0; i < 100; i++) + { + manager.GetOrCreateClient(host, httpClient, config); + } + + var tasks = new Task[100]; + + // Act - Release all concurrently + for (int i = 0; i < 100; i++) + { + tasks[i] = manager.ReleaseClientAsync(host); + } + + await Task.WhenAll(tasks); + + // Assert - Client should be closed (exactly once) + Assert.False(manager.HasClient(host)); + Assert.Equal(1, closeCount); + } + + [Fact] + public async 
Task TelemetryClientManager_ConcurrentGetAndRelease_ThreadSafe() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "get-release-host.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Create initial reference + manager.GetOrCreateClient(host, httpClient, config); + + var tasks = new List(); + + // Act - Mix of gets and releases + for (int i = 0; i < 50; i++) + { + tasks.Add(Task.Run(() => manager.GetOrCreateClient(host, httpClient, config))); + tasks.Add(manager.ReleaseClientAsync(host)); + } + + await Task.WhenAll(tasks); + + // Assert - No exceptions thrown (thread safety verified) + // Final state depends on timing, but should be consistent + } + + #endregion + + #region TelemetryClientManager Helper Method Tests + + [Fact] + public void TelemetryClientManager_TryGetHolder_ExistingHost_ReturnsTrue() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var host = "try-get-host.databricks.com"; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + manager.GetOrCreateClient(host, httpClient, config); + + // Act + var result = manager.TryGetHolder(host, out var holder); + + // Assert + Assert.True(result); + Assert.NotNull(holder); + Assert.Equal(1, holder!.RefCount); + } + + [Fact] + public void TelemetryClientManager_TryGetHolder_UnknownHost_ReturnsFalse() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act + var result = manager.TryGetHolder("unknown.databricks.com", out var holder); + + // Assert + Assert.False(result); + Assert.Null(holder); + } + + [Fact] + public void TelemetryClientManager_TryGetHolder_NullHost_ReturnsFalse() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act + var result = manager.TryGetHolder(null!, out var holder); + + // Assert + Assert.False(result); + Assert.Null(holder); + } + + [Fact] + public void TelemetryClientManager_HasClient_NullHost_ReturnsFalse() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + + // Act + var result = manager.HasClient(null!); + + // Assert + Assert.False(result); + } + + [Fact] + public void TelemetryClientManager_Clear_RemovesAllClients() + { + // Arrange + var manager = CreateManagerWithMockFactory(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + manager.GetOrCreateClient("host1.databricks.com", httpClient, config); + manager.GetOrCreateClient("host2.databricks.com", httpClient, config); + manager.GetOrCreateClient("host3.databricks.com", httpClient, config); + Assert.Equal(3, manager.ClientCount); + + // Act + manager.Clear(); + + // Assert + Assert.Equal(0, manager.ClientCount); + } + + #endregion + + #region Helper Methods + + private TelemetryClientManager CreateManagerWithMockFactory() + { + return new TelemetryClientManager((host, httpClient, config) => + new MockTelemetryClient(host)); + } + + private TelemetryFrontendLog CreateTestLog() + { + return new TelemetryFrontendLog + { + FrontendLogEventId = "test-event-id", + WorkspaceId = 12345 + }; + } + + #endregion + + #region Mock Classes + + private class MockTelemetryClient : ITelemetryClient + { + public string Host { get; } + public int ExportCallCount { get; private set; } + public IReadOnlyList? LastExportedLogs { get; private set; } + public Action? 
OnClose { get; set; } + + public MockTelemetryClient(string host) + { + Host = host; + } + + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + ExportCallCount++; + LastExportedLogs = logs; + return Task.CompletedTask; + } + + public Task CloseAsync() + { + OnClose?.Invoke(); + return Task.CompletedTask; + } + } + + private class MockTelemetryExporter : ITelemetryExporter + { + public int ExportCallCount { get; private set; } + public IReadOnlyList? LastExportedLogs { get; private set; } + + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + ExportCallCount++; + LastExportedLogs = logs; + return Task.CompletedTask; + } + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/TelemetryClientTests.cs b/csharp/test/Unit/Telemetry/TelemetryClientTests.cs new file mode 100644 index 00000000..50917151 --- /dev/null +++ b/csharp/test/Unit/Telemetry/TelemetryClientTests.cs @@ -0,0 +1,510 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Unit tests for TelemetryClient. 
+ /// Tests verify: + /// - Constructor initializes listener, aggregator, exporter + /// - ExportAsync delegates to exporter + /// - CloseAsync cancels background task, flushes, disposes + /// - All exceptions swallowed during close + /// + public class TelemetryClientTests + { + private const string TestHost = "test-host.databricks.com"; + + #region Constructor Tests + + [Fact] + public void TelemetryClient_Constructor_InitializesComponents() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration { Enabled = true }; + + // Act + var client = new TelemetryClient(TestHost, httpClient, config); + + // Assert + Assert.NotNull(client); + Assert.Equal(TestHost, client.Host); + Assert.NotNull(client.Listener); + Assert.NotNull(client.Aggregator); + Assert.NotNull(client.Exporter); + Assert.True(client.Listener.IsStarted); + Assert.False(client.IsClosed); + } + + [Fact] + public void TelemetryClient_Constructor_WithWorkspaceIdAndUserAgent() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration { Enabled = true }; + long workspaceId = 12345; + string userAgent = "TestAgent/1.0"; + + // Act + var client = new TelemetryClient(TestHost, httpClient, config, workspaceId, userAgent); + + // Assert + Assert.NotNull(client); + Assert.Equal(TestHost, client.Host); + } + + [Fact] + public void TelemetryClient_Constructor_NullHost_ThrowsException() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => new TelemetryClient(null!, httpClient, config)); + } + + [Fact] + public void TelemetryClient_Constructor_EmptyHost_ThrowsException() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => new TelemetryClient("", httpClient, config)); + } + + [Fact] + public void TelemetryClient_Constructor_WhitespaceHost_ThrowsException() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => new TelemetryClient(" ", httpClient, config)); + } + + [Fact] + public void TelemetryClient_Constructor_NullHttpClient_ThrowsException() + { + // Arrange + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => new TelemetryClient(TestHost, null!, config)); + } + + [Fact] + public void TelemetryClient_Constructor_NullConfig_ThrowsException() + { + // Arrange + var httpClient = new HttpClient(); + + // Act & Assert + Assert.Throws(() => new TelemetryClient(TestHost, httpClient, null!)); + } + + [Fact] + public void TelemetryClient_Constructor_WithCustomExporter_UsesProvidedExporter() + { + // Arrange + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var mockExporter = new MockTelemetryExporter(); + + // Act + var client = new TelemetryClient( + TestHost, httpClient, config, + workspaceId: 0, + userAgent: null, + exporter: mockExporter, + aggregator: null, + listener: null); + + // Assert + Assert.Same(mockExporter, client.Exporter); + } + + #endregion + + #region ExportAsync Tests + + [Fact] + public async Task TelemetryClient_ExportAsync_DelegatesToExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = CreateClientWithMockExporter(mockExporter, config); + var logs = new List { CreateTestLog() }; + + // Act + await 
client.ExportAsync(logs); + + // Assert + Assert.Equal(1, mockExporter.ExportCallCount); + Assert.Same(logs, mockExporter.LastExportedLogs); + } + + [Fact] + public async Task TelemetryClient_ExportAsync_WithCancellation_PropagatesCancellation() + { + // Arrange + var mockExporter = new MockTelemetryExporter { ShouldThrowCancellation = true }; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = CreateClientWithMockExporter(mockExporter, config); + var logs = new List { CreateTestLog() }; + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + // Act & Assert + await Assert.ThrowsAsync(() => client.ExportAsync(logs, cts.Token)); + } + + [Fact] + public async Task TelemetryClient_ExportAsync_EmptyLogs_DoesNotCallExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = CreateClientWithMockExporter(mockExporter, config); + var logs = new List(); + + // Act + await client.ExportAsync(logs); + + // Assert - Exporter might be called but with empty list - check behavior + // The circuit breaker exporter returns early for empty lists + Assert.Equal(0, mockExporter.ExportCallCount); + } + + [Fact] + public async Task TelemetryClient_ExportAsync_NullLogs_DoesNotCallExporter() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = CreateClientWithMockExporter(mockExporter, config); + + // Act + await client.ExportAsync(null!); + + // Assert - Exporter might be called but with null - check behavior + // The circuit breaker exporter returns early for null + Assert.Equal(0, mockExporter.ExportCallCount); + } + + #endregion + + #region CloseAsync Tests + + [Fact] + public async Task TelemetryClient_CloseAsync_FlushesAndCancels() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration { FlushIntervalMs = 10000 }; // Long interval + var client = CreateClientWithMockExporter(mockExporter, config); + + // Act + await client.CloseAsync(); + + // Assert + Assert.True(client.IsClosed); + Assert.True(client.Listener.IsDisposed); + } + + [Fact] + public async Task TelemetryClient_CloseAsync_Idempotent() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = CreateClientWithMockExporter(mockExporter, config); + + // Act - Call multiple times + await client.CloseAsync(); + await client.CloseAsync(); + await client.CloseAsync(); + + // Assert - No exceptions, client remains closed + Assert.True(client.IsClosed); + } + + [Fact] + public async Task TelemetryClient_CloseAsync_ExceptionSwallowed() + { + // Arrange + var mockExporter = new MockTelemetryExporter { ShouldThrowOnExport = true }; + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration { FlushIntervalMs = 50 }; + var client = CreateClientWithMockExporter(mockExporter, config); + + // Act - should not throw even if internal operations fail + await client.CloseAsync(); + + // Assert + Assert.True(client.IsClosed); + } + + [Fact] + public async Task TelemetryClient_CloseAsync_StopsListener() + { + // Arrange + var mockExporter = new MockTelemetryExporter(); + var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var client = 
+            Assert.True(client.Listener.IsStarted);
+
+            // Act
+            await client.CloseAsync();
+
+            // Assert
+            Assert.True(client.Listener.IsDisposed);
+        }
+
+        [Fact]
+        public async Task TelemetryClient_CloseAsync_WaitsForBackgroundTask()
+        {
+            // Arrange
+            var mockExporter = new MockTelemetryExporter();
+            var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration { FlushIntervalMs = 100 }; // Short interval
+            var client = CreateClientWithMockExporter(mockExporter, config);
+
+            // Let background task run briefly
+            await Task.Delay(50);
+
+            // Act
+            var closeTask = client.CloseAsync();
+
+            // Wait with timeout
+            var completed = await Task.WhenAny(closeTask, Task.Delay(TimeSpan.FromSeconds(10)));
+
+            // Assert - Close should complete within reasonable time
+            Assert.Same(closeTask, completed);
+        }
+
+        #endregion
+
+        #region Background Flush Task Tests
+
+        [Fact]
+        public async Task TelemetryClient_BackgroundFlush_FlushesMetrics()
+        {
+            // Arrange
+            var mockExporter = new MockTelemetryExporter();
+            var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration
+            {
+                Enabled = true,
+                FlushIntervalMs = 100, // Short interval for testing
+                BatchSize = 1000       // Large batch size to prevent immediate flush
+            };
+            var client = CreateClientWithMockExporter(mockExporter, config);
+
+            // Export some logs so the background flush has data to send
+            var logs = new List<TelemetryFrontendLog> { CreateTestLog() };
+            await client.ExportAsync(logs);
+
+            // Wait for background flush
+            await Task.Delay(200);
+
+            // Assert - Background flush should have triggered
+            Assert.True(mockExporter.ExportCallCount >= 1);
+
+            // Cleanup
+            await client.CloseAsync();
+        }
+
+        [Fact]
+        public async Task TelemetryClient_BackgroundFlush_StopsOnClose()
+        {
+            // Arrange
+            var mockExporter = new MockTelemetryExporter();
+            var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration
+            {
+                Enabled = true,
+                FlushIntervalMs = 50
+            };
+            var client = CreateClientWithMockExporter(mockExporter, config);
+
+            // Wait for a few flush cycles
+            await Task.Delay(150);
+            var countBeforeClose = mockExporter.ExportCallCount;
+
+            // Act
+            await client.CloseAsync();
+
+            // Wait and verify no more flushes
+            await Task.Delay(150);
+
+            // Assert - Export count should not increase significantly after close
+            // Allow for one more flush during close
+            Assert.True(mockExporter.ExportCallCount <= countBeforeClose + 1);
+        }
+
+        #endregion
+
+        #region Thread Safety Tests
+
+        [Fact]
+        public async Task TelemetryClient_ConcurrentExport_ThreadSafe()
+        {
+            // Arrange
+            var mockExporter = new MockTelemetryExporter();
+            var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration();
+            var client = CreateClientWithMockExporter(mockExporter, config);
+            var tasks = new List<Task>();
+            var logs = new List<TelemetryFrontendLog> { CreateTestLog() };
+
+            // Act - Concurrent exports
+            for (int i = 0; i < 100; i++)
+            {
+                tasks.Add(client.ExportAsync(logs));
+            }
+
+            await Task.WhenAll(tasks);
+
+            // Assert - All exports should complete without exception
+            Assert.Equal(100, mockExporter.ExportCallCount);
+
+            // Cleanup
+            await client.CloseAsync();
+        }
+
+        [Fact]
+        public async Task TelemetryClient_ExportDuringClose_ThreadSafe()
+        {
+            // Arrange
+            var mockExporter = new MockTelemetryExporter();
+            var httpClient = new HttpClient();
+            var config = new TelemetryConfiguration();
+            var client = CreateClientWithMockExporter(mockExporter, config);
+            var logs = new List<TelemetryFrontendLog> { CreateTestLog() };
+
+            // Act - Start closing and export concurrently
+            var closeTask = client.CloseAsync();
+            var exportTasks = new List<Task>();
+            for (int i = 0; i < 10; i++)
+            {
+                exportTasks.Add(client.ExportAsync(logs));
+            }
+
+            // All should complete without throwing
+            await Task.WhenAll(exportTasks);
+            await closeTask;
+
+            // Assert
+            Assert.True(client.IsClosed);
+        }
+
+        #endregion
+
+        #region Helper Methods
+
+        private TelemetryClient CreateClientWithMockExporter(MockTelemetryExporter exporter, TelemetryConfiguration config)
+        {
+            var httpClient = new HttpClient();
+
+            // Create a real aggregator wired to the mock exporter
+            var mockAggregator = new MetricsAggregator(exporter, config, workspaceId: 12345, userAgent: "TestAgent");
+
+            // Create client with injected dependencies
+            return new TelemetryClient(
+                TestHost,
+                httpClient,
+                config,
+                workspaceId: 12345,
+                userAgent: "TestAgent",
+                exporter: exporter,
+                aggregator: mockAggregator,
+                listener: null);
+        }
+
+        private TelemetryFrontendLog CreateTestLog()
+        {
+            return new TelemetryFrontendLog
+            {
+                FrontendLogEventId = Guid.NewGuid().ToString(),
+                WorkspaceId = 12345,
+                Context = new FrontendLogContext
+                {
+                    ClientContext = new TelemetryClientContext { UserAgent = "TestAgent" },
+                    TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()
+                }
+            };
+        }
+
+        #endregion
+
+        #region Mock Classes
+
+        private class MockTelemetryExporter : ITelemetryExporter
+        {
+            private int _exportCallCount;
+            public int ExportCallCount => _exportCallCount;
+            public IReadOnlyList<TelemetryFrontendLog>? LastExportedLogs { get; private set; }
+            public bool ShouldThrowOnExport { get; set; }
+            public bool ShouldThrowCancellation { get; set; }
+
+            public Task ExportAsync(IReadOnlyList<TelemetryFrontendLog> logs, CancellationToken ct = default)
+            {
+                if (ShouldThrowCancellation)
+                {
+                    throw new OperationCanceledException();
+                }
+
+                if (ShouldThrowOnExport)
+                {
+                    throw new InvalidOperationException("Export failed");
+                }
+
+                if (logs == null || logs.Count == 0)
+                {
+                    return Task.CompletedTask;
+                }
+
+                Interlocked.Increment(ref _exportCallCount);
+                LastExportedLogs = logs;
+                return Task.CompletedTask;
+            }
+        }
+
+        #endregion
+    }
+}
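+
+// Illustrative sketch only, not part of the proposed API: a listener test double that the
+// CloseAsync tests above could inject via the `listener:` parameter instead of relying on
+// the client's default listener. The ITelemetryListener name and the Start/IsStarted/IsDisposed
+// members are assumptions inferred from the assertions in this file; shown standalone here,
+// it would live next to MockTelemetryExporter in the test class.
+internal sealed class MockTelemetryListener : ITelemetryListener
+{
+    public bool IsStarted { get; private set; }
+    public bool IsDisposed { get; private set; }
+
+    // Mark the listener as started; the client would call this during construction.
+    public void Start() => IsStarted = true;
+
+    // Record disposal so tests can assert that CloseAsync tore the listener down.
+    public void Dispose()
+    {
+        IsStarted = false;
+        IsDisposed = true;
+    }
+}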