diff --git a/csharp/Directory.Packages.props b/csharp/Directory.Packages.props index 58713e76..c98b74e0 100644 --- a/csharp/Directory.Packages.props +++ b/csharp/Directory.Packages.props @@ -35,6 +35,8 @@ + + diff --git a/csharp/doc/telemetry-design.md b/csharp/doc/telemetry-design.md index 9e0717b7..f012b2dd 100644 --- a/csharp/doc/telemetry-design.md +++ b/csharp/doc/telemetry-design.md @@ -176,69 +176,419 @@ sequenceDiagram ### 3.1 FeatureFlagCache (Per-Host) -**Purpose**: Cache feature flag values at the host level to avoid repeated API calls and rate limiting. +**Purpose**: Cache **all** feature flag values at the host level to avoid repeated API calls and rate limiting. This is a generic cache that can be used for any driver configuration controlled by server-side feature flags, not just telemetry. -**Location**: `Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.FeatureFlagCache` +**Location**: `AdbcDrivers.Databricks.FeatureFlagCache` (note: not in Telemetry namespace - this is a general-purpose component) #### Rationale +- **Generic feature flag support**: Cache returns all flags, allowing any driver feature to be controlled server-side - **Per-host caching**: Feature flags cached by host (not per connection) to prevent rate limiting - **Reference counting**: Tracks number of connections per host for proper cleanup -- **Automatic expiration**: Refreshes cached flags after TTL expires (15 minutes) +- **Server-controlled TTL**: Refresh interval controlled by server-provided `ttl_seconds` (default: 15 minutes) +- **Background refresh**: Scheduled refresh at server-specified intervals - **Thread-safe**: Uses ConcurrentDictionary for concurrent access from multiple connections +#### Configuration Priority Order + +Feature flags are integrated directly into the existing ADBC driver property parsing logic as an **extra layer** in the property value resolution. The priority order is: + +``` +1. User-specified properties (highest priority) +2. 
Feature flags from server +3. Driver default values (lowest priority) +``` + +**Integration Approach**: Feature flags are merged into the `Properties` dictionary at connection initialization time. This means: +- The existing `Properties.TryGetValue()` pattern continues to work unchanged +- Feature flags are transparently available as properties +- No changes needed to existing property parsing code + +```mermaid +flowchart LR + A[User Properties] --> D[Merged Properties] + B[Feature Flags] --> D + C[Driver Defaults] --> D + D --> E[Properties.TryGetValue] + E --> F[Existing Parsing Logic] +``` + +**Merge Logic** (in `DatabricksConnection` initialization): +```csharp +// Current flow: +// 1. User properties from connection string/config +// 2. Environment properties from DATABRICKS_CONFIG_FILE + +// New flow with feature flags: +// 1. User properties from connection string/config (highest priority) +// 2. Feature flags from server (middle priority) +// 3. Environment properties / driver defaults (lowest priority) + +private Dictionary MergePropertiesWithFeatureFlags( + Dictionary userProperties, + IReadOnlyDictionary featureFlags) +{ + var merged = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Start with feature flags as base (lower priority) + foreach (var flag in featureFlags) + { + // Map feature flag names to property names if needed + string propertyName = MapFeatureFlagToPropertyName(flag.Key); + if (propertyName != null) + { + merged[propertyName] = flag.Value; + } + } + + // Override with user properties (higher priority) + foreach (var prop in userProperties) + { + merged[prop.Key] = prop.Value; + } + + return merged; +} +``` + +**Feature Flag to Property Name Mapping**: +```csharp +// Feature flags have long names, map to driver property names +private static readonly Dictionary FeatureFlagToPropertyMap = new() +{ + ["databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"] = "telemetry.enabled", + 
["databricks.partnerplatform.clientConfigsFeatureFlags.enableCloudFetch"] = "cloudfetch.enabled", + // ... more mappings +}; +``` + +This approach: +- **Preserves existing code**: All `Properties.TryGetValue()` calls work unchanged +- **Transparent integration**: Feature flags appear as regular properties after merge +- **Clear priority**: User settings always win over server flags +- **Single merge point**: Feature flag integration happens once at connection initialization +- **Fresh values per connection**: Each new connection uses the latest cached feature flag values + +#### Per-Connection Property Resolution + +Each new connection applies property merging with the **latest** cached feature flag values: + +```mermaid +sequenceDiagram + participant C1 as Connection 1 + participant C2 as Connection 2 + participant FFC as FeatureFlagCache + participant BG as Background Refresh + + Note over FFC: Cache has flags v1 + + C1->>FFC: GetOrCreateContext(host) + FFC-->>C1: context (refCount=1) + C1->>C1: Merge properties with flags v1 + Note over C1: Properties frozen with v1 + + BG->>FFC: Refresh flags + Note over FFC: Cache updated to flags v2 + + C2->>FFC: GetOrCreateContext(host) + FFC-->>C2: context (refCount=2) + C2->>C2: Merge properties with flags v2 + Note over C2: Properties frozen with v2 + + Note over C1,C2: C1 uses v1, C2 uses v2 +``` + +**Key Points**: +- **Shared cache, per-connection merge**: The `FeatureFlagCache` is shared (per-host), but property merging happens at each connection initialization +- **Latest values for new connections**: When a new connection is created, it reads the current cached values (which may have been updated by background refresh) +- **Stable within connection**: Once merged, a connection's `Properties` are stable for its lifetime (no mid-connection changes) +- **Background refresh benefits new connections**: The scheduled refresh ensures new connections get up-to-date flag values without waiting for a fetch + +#### Feature Flag 
API + +**Endpoint**: `GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{driver_version}` + +> **Note**: Currently using the JDBC endpoint (`OSS_JDBC`) until the ADBC endpoint (`OSS_ADBC`) is configured server-side. The feature flag name will still use `enableTelemetryForAdbc` to distinguish ADBC telemetry from JDBC telemetry. + +Where `{driver_version}` is the driver version (e.g., `1.0.0`). + +**Request Headers**: +- `Authorization`: Bearer token (same as connection auth) +- `User-Agent`: Custom user agent for connector service + +**Response Format** (JSON): +```json +{ + "flags": [ + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc", + "value": "true" + }, + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.enableCloudFetch", + "value": "true" + }, + { + "name": "databricks.partnerplatform.clientConfigsFeatureFlags.maxDownloadThreads", + "value": "10" + } + ], + "ttl_seconds": 900 +} +``` + +**Response Fields**: +- `flags`: Array of feature flag entries with `name` and `value` (string). Names can be mapped to driver property names. +- `ttl_seconds`: Server-controlled refresh interval in seconds (default: 900 = 15 minutes) + +**JDBC Reference**: See `DatabricksDriverFeatureFlagsContext.java:30-33` for endpoint format. + +#### Refresh Strategy + +The feature flag cache follows the JDBC driver pattern: + +1. **Initial Blocking Fetch**: On connection open, make a blocking HTTP call to fetch all feature flags +2. **Cache All Flags**: Store all returned flags in a local cache (Guava Cache in JDBC, ConcurrentDictionary in C#) +3. **Scheduled Background Refresh**: Start a daemon thread that refreshes flags at intervals based on `ttl_seconds` +4. 
**Dynamic TTL**: If server returns a different `ttl_seconds`, reschedule the refresh interval + +```mermaid +sequenceDiagram + participant Conn as Connection.Open + participant FFC as FeatureFlagContext + participant Server as Databricks Server + + Conn->>FFC: GetOrCreateContext(host) + FFC->>Server: GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{version} + Note over FFC,Server: Blocking initial fetch + Server-->>FFC: {flags: [...], ttl_seconds: 900} + FFC->>FFC: Cache all flags + FFC->>FFC: Schedule refresh at ttl_seconds interval + + loop Every ttl_seconds + FFC->>Server: GET /api/2.0/connector-service/feature-flags/OSS_JDBC/{version} + Server-->>FFC: {flags: [...], ttl_seconds: N} + FFC->>FFC: Update cache, reschedule if TTL changed + end +``` + +**JDBC Reference**: See `DatabricksDriverFeatureFlagsContext.java:48-58` for initial fetch and scheduling. + +#### HTTP Client Pattern + +The feature flag cache does **not** use a separate dedicated HTTP client. Instead, it reuses the connection's existing HTTP client infrastructure: + +```mermaid +graph LR + A[DatabricksConnection] -->|provides| B[HttpClient] + A -->|provides| C[Auth Headers] + B --> D[FeatureFlagContext] + C --> D + D -->|HTTP GET| E[Feature Flag Endpoint] +``` + +**Key Points**: +1. **Reuse connection's HttpClient**: The `FeatureFlagContext` receives the connection's `HttpClient` (already configured with base address, timeouts, etc.) +2. **Reuse connection's authentication**: Auth headers (Bearer token) come from the connection's authentication mechanism +3. 
**Custom User-Agent**: Set a connector-service-specific User-Agent header for the feature flag requests + +**JDBC Implementation** (`DatabricksDriverFeatureFlagsContext.java:89-105`): +```java +// Get shared HTTP client from connection +IDatabricksHttpClient httpClient = + DatabricksHttpClientFactory.getInstance().getClient(connectionContext); + +// Create request +HttpGet request = new HttpGet(featureFlagEndpoint); + +// Set custom User-Agent for connector service +request.setHeader("User-Agent", + UserAgentManager.buildUserAgentForConnectorService(connectionContext)); + +// Add auth headers from connection's auth config +DatabricksClientConfiguratorManager.getInstance() + .getConfigurator(connectionContext) + .getDatabricksConfig() + .authenticate() + .forEach(request::addHeader); +``` + +**C# Equivalent Pattern**: +```csharp +// In DatabricksConnection - create HttpClient for feature flags +private HttpClient CreateFeatureFlagHttpClient() +{ + var handler = HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions, _proxyConfigurator); + var httpClient = new HttpClient(handler); + + // Set base address + httpClient.BaseAddress = new Uri($"https://{_host}"); + + // Set auth header (reuse connection's token) + if (Properties.TryGetValue(SparkParameters.Token, out string? 
token)) + { + httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", token); + } + + // Set custom User-Agent for connector service + httpClient.DefaultRequestHeaders.UserAgent.ParseAdd( + BuildConnectorServiceUserAgent()); + + return httpClient; +} + +// Pass to FeatureFlagContext +var context = featureFlagCache.GetOrCreateContext(_host, CreateFeatureFlagHttpClient()); +``` + +This approach: +- Avoids duplicating HTTP client configuration +- Ensures consistent authentication across all API calls +- Allows proper resource cleanup when connection closes + #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks { /// /// Singleton that manages feature flag cache per host. /// Prevents rate limiting by caching feature flag responses. + /// This is a generic cache for all feature flags, not just telemetry. /// internal sealed class FeatureFlagCache { - private static readonly FeatureFlagCache Instance = new(); - public static FeatureFlagCache GetInstance() => Instance; + private static readonly FeatureFlagCache s_instance = new FeatureFlagCache(); + public static FeatureFlagCache GetInstance() => s_instance; /// /// Gets or creates a feature flag context for the host. /// Increments reference count. + /// Makes initial blocking fetch if context is new. /// - public FeatureFlagContext GetOrCreateContext(string host); + public FeatureFlagContext GetOrCreateContext(string host, HttpClient httpClient, string driverVersion); /// /// Decrements reference count for the host. - /// Removes context when ref count reaches zero. + /// Removes context and stops refresh scheduler when ref count reaches zero. /// public void ReleaseContext(string host); + } + /// + /// Holds feature flag state and reference count for a host. + /// Manages background refresh scheduling. + /// Uses the HttpClient provided by the connection for API calls. 
+ /// + internal sealed class FeatureFlagContext : IDisposable + { /// - /// Checks if telemetry is enabled for the host. - /// Uses cached value if available and not expired. + /// Creates a new context with the given HTTP client. + /// Makes initial blocking fetch to populate cache. + /// Starts background refresh scheduler. /// - public Task IsTelemetryEnabledAsync( - string host, - HttpClient httpClient, - CancellationToken ct = default); + /// The Databricks host. + /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + public FeatureFlagContext(string host, HttpClient httpClient); + + public int RefCount { get; } + public TimeSpan RefreshInterval { get; } // From server ttl_seconds + + /// + /// Gets a feature flag value by name. + /// Returns null if the flag is not found. + /// + public string? GetFlagValue(string flagName); + + /// + /// Checks if a feature flag is enabled (value is "true"). + /// Returns false if flag is not found or value is not "true". + /// + public bool IsFeatureEnabled(string flagName); + + /// + /// Gets all cached feature flags as a dictionary. + /// Can be used to merge with user properties. + /// + public IReadOnlyDictionary GetAllFlags(); + + /// + /// Stops the background refresh scheduler. + /// + public void Shutdown(); + + public void Dispose(); } /// - /// Holds feature flag state and reference count for a host. + /// Response model for feature flags API. /// - internal sealed class FeatureFlagContext + internal sealed class FeatureFlagsResponse { - public bool? TelemetryEnabled { get; set; } - public DateTime? LastFetched { get; set; } - public int RefCount { get; set; } - public TimeSpan CacheDuration { get; } = TimeSpan.FromMinutes(15); + public List? Flags { get; set; } + public int? 
TtlSeconds { get; set; } + } - public bool IsExpired => LastFetched == null || - DateTime.UtcNow - LastFetched.Value > CacheDuration; + internal sealed class FeatureFlagEntry + { + public string Name { get; set; } = string.Empty; + public string Value { get; set; } = string.Empty; } } ``` -**JDBC Reference**: `DatabricksDriverFeatureFlagsContextFactory.java:27` maintains per-compute (host) feature flag contexts with reference counting. +#### Usage Example + +```csharp +// In DatabricksConnection constructor/initialization +// This runs for EACH new connection, using LATEST cached feature flags + +// Step 1: Get or create feature flag context +// - If context exists: returns existing context with latest cached flags +// - If new: creates context, does initial blocking fetch, starts background refresh +var featureFlagCache = FeatureFlagCache.GetInstance(); +var featureFlagContext = featureFlagCache.GetOrCreateContext(_host, CreateFeatureFlagHttpClient(), s_assemblyVersion); + +// Step 2: Merge feature flags into properties using LATEST cached values +// Each new connection gets a fresh merge with current flag values +Properties = MergePropertiesWithFeatureFlags( + userProperties, + featureFlagContext.GetAllFlags()); // Returns current cached flags + +// Step 3: Existing property parsing works unchanged! +// Feature flags are now transparently available as properties +bool IsTelemetryEnabled() +{ + // This works whether the value came from: + // - User property (highest priority) + // - Feature flag (merged in) + // - Or falls back to driver default + if (Properties.TryGetValue("telemetry.enabled", out var value)) + { + return bool.TryParse(value, out var result) && result; + } + return true; // Driver default +} + +// Same pattern for all other properties - no changes needed! +if (Properties.TryGetValue(DatabricksParameters.CloudFetchEnabled, out var cfValue)) +{ + // Value could be from user OR from feature flag - transparent! 
+} +``` + +**Key Benefits**: +- Existing code like `Properties.TryGetValue()` continues to work unchanged +- Each new connection uses the **latest** cached feature flag values +- Feature flag integration is a one-time merge at connection initialization +- Properties are stable for the lifetime of the connection (no mid-connection changes) + +**JDBC Reference**: `DatabricksDriverFeatureFlagsContextFactory.java:27` maintains per-compute (host) feature flag contexts with reference counting. `DatabricksDriverFeatureFlagsContext.java` implements the caching, refresh scheduling, and API calls. --- @@ -257,7 +607,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Singleton factory that manages one telemetry client per host. @@ -319,7 +669,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Wraps telemetry exporter with circuit breaker pattern. @@ -372,7 +722,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Custom ActivityListener that aggregates metrics from Activity events @@ -460,7 +810,7 @@ private ActivityListener CreateListener() #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { /// /// Aggregates metrics from activities by statement_id and includes session_id. @@ -549,38 +899,58 @@ flowchart TD **Purpose**: Export aggregated metrics to Databricks telemetry service. 
-**Location**: `Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.DatabricksTelemetryExporter` +**Location**: `AdbcDrivers.Databricks.Telemetry.DatabricksTelemetryExporter` + +**Status**: Implemented (WI-3.4) #### Interface ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry +namespace AdbcDrivers.Databricks.Telemetry { public interface ITelemetryExporter { /// - /// Export metrics to Databricks service. Never throws. + /// Export telemetry frontend logs to the backend service. + /// Never throws exceptions (all swallowed and logged at TRACE level). /// Task ExportAsync( - IReadOnlyList metrics, + IReadOnlyList logs, CancellationToken ct = default); } internal sealed class DatabricksTelemetryExporter : ITelemetryExporter { + // Authenticated telemetry endpoint + internal const string AuthenticatedEndpoint = "/telemetry-ext"; + + // Unauthenticated telemetry endpoint + internal const string UnauthenticatedEndpoint = "/telemetry-unauth"; + public DatabricksTelemetryExporter( HttpClient httpClient, - DatabricksConnection connection, + string host, + bool isAuthenticated, TelemetryConfiguration config); public Task ExportAsync( - IReadOnlyList metrics, + IReadOnlyList logs, CancellationToken ct = default); + + // Creates TelemetryRequest wrapper with uploadTime and protoLogs + internal TelemetryRequest CreateTelemetryRequest(IReadOnlyList logs); } } ``` -**Same implementation as original design**: Circuit breaker, retry logic, endpoints. 
+**Implementation Details**: +- Creates `TelemetryRequest` with `uploadTime` (Unix ms) and `protoLogs` (JSON-serialized `TelemetryFrontendLog` array) +- Uses `/telemetry-ext` for authenticated requests +- Uses `/telemetry-unauth` for unauthenticated requests +- Implements retry logic for transient failures (configurable via `MaxRetries` and `RetryDelayMs`) +- Uses `ExceptionClassifier` to identify terminal vs retryable errors +- Never throws exceptions (all caught and logged at TRACE level) +- Cancellation is propagated (not swallowed) --- @@ -609,7 +979,7 @@ Telemetry/ **File**: `TagDefinitions/TelemetryTag.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Defines export scope for telemetry tags. @@ -648,7 +1018,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/ConnectionOpenEvent.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Tag definitions for Connection.Open events. @@ -726,7 +1096,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/StatementExecutionEvent.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Tag definitions for Statement execution events. @@ -812,7 +1182,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions **File**: `TagDefinitions/TelemetryTagRegistry.cs` ```csharp -namespace Apache.Arrow.Adbc.Drivers.Databricks.Telemetry.TagDefinitions +namespace AdbcDrivers.Databricks.Telemetry.TagDefinitions { /// /// Central registry for all telemetry tags and events. 
@@ -1069,9 +1439,15 @@ public sealed class TelemetryConfiguration public int CircuitBreakerThreshold { get; set; } = 5; public TimeSpan CircuitBreakerTimeout { get; set; } = TimeSpan.FromMinutes(1); - // Feature flag + // Feature flag name to check in the cached flags public const string FeatureFlagName = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"; + + // Feature flag endpoint (relative to host) + // {0} = driver version without OSS suffix + // NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side + public const string FeatureFlagEndpointFormat = + "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; } ``` @@ -1801,13 +2177,25 @@ The Activity-based design was selected because it: ## 12. Implementation Checklist ### Phase 1: Feature Flag Cache & Per-Host Management -- [ ] Create `FeatureFlagCache` singleton with per-host contexts +- [ ] Create `FeatureFlagCache` singleton with per-host contexts (in `AdbcDrivers.Databricks` namespace, not Telemetry) +- [ ] Make cache generic - return all flags, not just telemetry-specific ones - [ ] Implement `FeatureFlagContext` with reference counting -- [ ] Add cache expiration logic (15 minute TTL) -- [ ] Implement `FetchFeatureFlagAsync` to call feature endpoint +- [ ] Implement `GetFlagValue(string)` and `GetAllFlags()` methods for generic flag access +- [ ] Implement API call to `/api/2.0/connector-service/feature-flags/OSS_JDBC/{version}` (use JDBC endpoint initially) +- [ ] Parse `FeatureFlagsResponse` with `flags` array and `ttl_seconds` +- [ ] Implement initial blocking fetch on context creation +- [ ] Implement background refresh scheduler using server-provided `ttl_seconds` +- [ ] Add `Shutdown()` method to stop scheduler and cleanup +- [ ] Implement configuration priority: user properties > feature flags > driver defaults - [ ] Create `TelemetryClientManager` singleton - [ ] Implement `TelemetryClientHolder` with reference counting - [ ] Add unit 
tests for cache behavior and reference counting +- [ ] Add unit tests for background refresh scheduling +- [ ] Add unit tests for configuration priority order + +**Code Guidelines**: +- Avoid `#if` preprocessor directives - write code compatible with all target .NET versions (netstandard2.0, net472, net8.0) +- Use polyfills or runtime checks instead of compile-time conditionals where needed ### Phase 2: Circuit Breaker - [ ] Create `CircuitBreaker` class with state machine @@ -1837,7 +2225,7 @@ The Activity-based design was selected because it: ### Phase 5: Core Implementation - [ ] Create `DatabricksActivityListener` class - [ ] Create `MetricsAggregator` class (with exception buffering) -- [ ] Create `DatabricksTelemetryExporter` class +- [x] Create `DatabricksTelemetryExporter` class (WI-3.4) - [ ] Add necessary tags to existing activities (using defined constants) - [ ] Update connection to use per-host management @@ -1896,6 +2284,19 @@ This ensures compatibility with OTEL ecosystem. - Listener is optional (only activated when telemetry enabled) - Activity overhead already exists +### 13.4 Feature Flag Endpoint Migration + +**Question**: When should we migrate from `OSS_JDBC` to `OSS_ADBC` endpoint? + +**Current State**: The ADBC driver currently uses the JDBC feature flag endpoint (`/api/2.0/connector-service/feature-flags/OSS_JDBC/{version}`) because the ADBC endpoint is not yet configured on the server side. + +**Migration Plan**: +1. Server team configures the `OSS_ADBC` endpoint with appropriate feature flags +2. Update `TelemetryConfiguration.FeatureFlagEndpointFormat` to use `OSS_ADBC` +3. Coordinate with server team on feature flag name (`enableTelemetryForAdbc`) + +**Tracking**: Create a follow-up ticket to track this migration once server-side support is ready. + --- ## 14. References @@ -1920,8 +2321,12 @@ This ensures compatibility with OTEL ecosystem. 
- `CircuitBreakerTelemetryPushClient.java:15`: Circuit breaker wrapper - `CircuitBreakerManager.java:25`: Per-host circuit breaker management - `TelemetryPushClient.java:86-94`: Exception re-throwing for circuit breaker -- `TelemetryHelper.java:60-71`: Feature flag checking -- `DatabricksDriverFeatureFlagsContextFactory.java:27`: Per-host feature flag cache +- `TelemetryHelper.java:45-46,77`: Feature flag name and checking +- `DatabricksDriverFeatureFlagsContextFactory.java`: Per-host feature flag cache with reference counting +- `DatabricksDriverFeatureFlagsContext.java:30-33`: Feature flag API endpoint format +- `DatabricksDriverFeatureFlagsContext.java:48-58`: Initial fetch and background refresh scheduling +- `DatabricksDriverFeatureFlagsContext.java:89-140`: HTTP call and response parsing +- `FeatureFlagsResponse.java`: Response model with `flags` array and `ttl_seconds` --- diff --git a/csharp/doc/telemetry-sprint-plan.md b/csharp/doc/telemetry-sprint-plan.md index 02905a1c..2b7516bc 100644 --- a/csharp/doc/telemetry-sprint-plan.md +++ b/csharp/doc/telemetry-sprint-plan.md @@ -436,6 +436,8 @@ Implement the core telemetry infrastructure including feature flag management, p #### WI-5.3: MetricsAggregator **Description**: Aggregates Activity data by statement_id, handles exception buffering. 
+**Status**: ✅ **COMPLETED** + **Location**: `csharp/src/Telemetry/MetricsAggregator.cs` **Input**: @@ -443,7 +445,7 @@ Implement the core telemetry infrastructure including feature flag management, p - ITelemetryExporter for flushing **Output**: -- Aggregated TelemetryMetric per statement +- Aggregated TelemetryEvent per statement - Batched flush on threshold or interval **Test Expectations**: @@ -452,13 +454,35 @@ Implement the core telemetry infrastructure including feature flag management, p |-----------|-----------|-------|-----------------| | Unit | `MetricsAggregator_ProcessActivity_ConnectionOpen_EmitsImmediately` | Connection.Open activity | Metric queued for export | | Unit | `MetricsAggregator_ProcessActivity_Statement_AggregatesByStatementId` | Multiple activities with same statement_id | Single aggregated metric | -| Unit | `MetricsAggregator_CompleteStatement_EmitsAggregatedMetric` | Call CompleteStatement() | Queues aggregated metric | -| Unit | `MetricsAggregator_FlushAsync_BatchSizeReached_ExportsMetrics` | 100 metrics (batch size) | Calls exporter | -| Unit | `MetricsAggregator_FlushAsync_TimeInterval_ExportsMetrics` | Wait 5 seconds | Calls exporter | +| Unit | `MetricsAggregator_CompleteStatement_EmitsAggregatedEvent` | Call CompleteStatement() | Queues aggregated metric | +| Unit | `MetricsAggregator_FlushAsync_BatchSizeReached_ExportsEvents` | Batch size reached | Calls exporter | +| Unit | `MetricsAggregator_FlushAsync_TimeInterval_ExportsEvents` | Wait for interval | Calls exporter | | Unit | `MetricsAggregator_RecordException_Terminal_FlushesImmediately` | Terminal exception | Immediately exports error metric | | Unit | `MetricsAggregator_RecordException_Retryable_BuffersUntilComplete` | Retryable exception | Buffers, exports on CompleteStatement | | Unit | `MetricsAggregator_ProcessActivity_ExceptionSwallowed_NoThrow` | Activity processing throws | No exception propagated | | Unit | 
`MetricsAggregator_ProcessActivity_FiltersTags_UsingRegistry` | Activity with sensitive tags | Only safe tags in metric | +| Unit | `MetricsAggregator_WrapInFrontendLog_CreatesValidStructure` | TelemetryEvent | Valid TelemetryFrontendLog structure | + +**Implementation Notes**: +- Uses `ConcurrentDictionary` for thread-safe aggregation by statement_id +- Connection events emit immediately without aggregation +- Statement events are aggregated until `CompleteStatement()` is called +- Terminal exceptions (via `ExceptionClassifier`) are queued immediately +- Retryable exceptions are buffered and only emitted when `CompleteStatement(failed: true)` is called +- Uses `TelemetryTagRegistry.ShouldExportToDatabricks()` for tag filtering +- Creates `TelemetryFrontendLog` wrapper with workspace_id, client context, and timestamp +- All exceptions swallowed and logged at TRACE level using `Debug.WriteLine()` +- Timer-based periodic flush using `System.Threading.Timer` +- Comprehensive test coverage with 29 unit tests in `MetricsAggregatorTests.cs` +- Test file location: `csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs` + +**Key Design Decisions**: +1. **ConcurrentDictionary for aggregation**: Thread-safe statement aggregation without explicit locking +2. **Nested StatementTelemetryContext**: Holds aggregated metrics and buffered exceptions per statement +3. **Immediate connection events**: Connection open events don't require aggregation and are emitted immediately +4. **Exception buffering**: Retryable exceptions are buffered per statement and only emitted on failed completion +5. **Timer-based flush**: Uses `System.Threading.Timer` for periodic flush based on `FlushIntervalMs` +6. 
**Graceful disposal**: `Dispose()` stops timer and performs final flush --- diff --git a/csharp/src/AdbcDrivers.Databricks.csproj b/csharp/src/AdbcDrivers.Databricks.csproj index 2ca93e77..0a54336a 100644 --- a/csharp/src/AdbcDrivers.Databricks.csproj +++ b/csharp/src/AdbcDrivers.Databricks.csproj @@ -7,6 +7,7 @@ + diff --git a/csharp/src/Auth/AuthHelper.cs b/csharp/src/Auth/AuthHelper.cs new file mode 100644 index 00000000..0cfb01e8 --- /dev/null +++ b/csharp/src/Auth/AuthHelper.cs @@ -0,0 +1,128 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; + +namespace AdbcDrivers.Databricks.Auth +{ + /// + /// Helper methods for authentication operations. + /// Provides shared functionality used by HttpHandlerFactory and FeatureFlagCache. + /// + internal static class AuthHelper + { + /// + /// Gets the access token from connection properties. + /// Tries access_token first, then falls back to token. + /// + /// Connection properties. + /// The token, or null if not found. + public static string? GetTokenFromProperties(IReadOnlyDictionary properties) + { + if (properties.TryGetValue(SparkParameters.AccessToken, out string? accessToken) && !string.IsNullOrEmpty(accessToken)) + { + return accessToken; + } + + if (properties.TryGetValue(SparkParameters.Token, out string? 
token) && !string.IsNullOrEmpty(token)) + { + return token; + } + + return null; + } + + /// + /// Gets the access token based on authentication configuration. + /// Supports token-based (PAT) and OAuth M2M (client_credentials) authentication. + /// + /// The Databricks host. + /// Connection properties containing auth configuration. + /// HTTP client timeout for OAuth token requests. + /// The access token, or null if no valid authentication is configured. + public static string? GetAccessToken(string host, IReadOnlyDictionary properties, TimeSpan timeout) + { + // Check if OAuth authentication is configured + bool useOAuth = properties.TryGetValue(SparkParameters.AuthType, out string? authType) && + SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && + authTypeValue == SparkAuthType.OAuth; + + if (useOAuth) + { + // Determine grant type (defaults to AccessToken if not specified) + properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); + DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + + if (grantType == DatabricksOAuthGrantType.ClientCredentials) + { + // OAuth M2M authentication + return GetOAuthClientCredentialsToken(host, properties, timeout); + } + else if (grantType == DatabricksOAuthGrantType.AccessToken) + { + // OAuth with access_token grant type + return GetTokenFromProperties(properties); + } + } + + // Non-OAuth authentication: use static Bearer token if provided + return GetTokenFromProperties(properties); + } + + /// + /// Gets an OAuth access token using client credentials (M2M) flow. + /// + /// The Databricks host. + /// Connection properties containing OAuth credentials. + /// HTTP client timeout. + /// The access token, or null if credentials are missing or token acquisition fails. + public static string? 
GetOAuthClientCredentialsToken(string host, IReadOnlyDictionary properties, TimeSpan timeout) + { + properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); + properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); + properties.TryGetValue(DatabricksParameters.OAuthScope, out string? scope); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + try + { + // Create a separate HttpClient for OAuth token acquisition with TLS and proxy settings + using var oauthHttpClient = Http.HttpClientFactory.CreateOAuthHttpClient(properties, timeout); + + using var tokenProvider = new OAuthClientCredentialsProvider( + oauthHttpClient, + clientId, + clientSecret, + host, + scope: scope ?? "sql", + timeoutMinutes: 1); + + // Get access token synchronously (blocking call) + return tokenProvider.GetAccessTokenAsync().GetAwaiter().GetResult(); + } + catch + { + // Auth failures should be handled by caller + return null; + } + } + } +} diff --git a/csharp/src/DatabricksConnection.cs b/csharp/src/DatabricksConnection.cs index 6306363f..1cfa5187 100644 --- a/csharp/src/DatabricksConnection.cs +++ b/csharp/src/DatabricksConnection.cs @@ -126,12 +126,13 @@ internal DatabricksConnection( IReadOnlyDictionary properties, Microsoft.IO.RecyclableMemoryStreamManager? memoryStreamManager, System.Buffers.ArrayPool? lz4BufferPool) - : base(MergeWithDefaultEnvironmentConfig(properties)) + : base(FeatureFlagCache.GetInstance().MergePropertiesWithFeatureFlags(MergeWithDefaultEnvironmentConfig(properties), s_assemblyVersion)) { // Use provided manager (from Database) or create new instance (for direct construction) RecyclableMemoryStreamManager = memoryStreamManager ?? new Microsoft.IO.RecyclableMemoryStreamManager(); // Use provided pool (from Database) or create new instance (for direct construction) Lz4BufferPool = lz4BufferPool ?? 
System.Buffers.ArrayPool.Create(maxArrayLength: 4 * 1024 * 1024, maxArraysPerBucket: 10); + ValidateProperties(); } @@ -526,7 +527,7 @@ internal override IArrowArrayStream NewReader(T statement, Schema schema, IRe isLz4Compressed = metadataResp.Lz4Compressed; } - HttpClient httpClient = new HttpClient(HiveServer2TlsImpl.NewHttpClientHandler(TlsOptions, _proxyConfigurator)); + HttpClient httpClient = HttpClientFactory.CreateCloudFetchHttpClient(Properties); return new DatabricksCompositeReader(databricksStatement, schema, response, isLz4Compressed, httpClient); } diff --git a/csharp/src/DatabricksParameters.cs b/csharp/src/DatabricksParameters.cs index 77236931..3bfd4ba1 100644 --- a/csharp/src/DatabricksParameters.cs +++ b/csharp/src/DatabricksParameters.cs @@ -357,6 +357,12 @@ public class DatabricksParameters : SparkParameters /// public const string EnableSessionManagement = "adbc.databricks.rest.enable_session_management"; + /// + /// Timeout in seconds for feature flag API requests. + /// Default value is 10 seconds if not specified. + /// + public const string FeatureFlagTimeoutSeconds = "adbc.databricks.feature_flag_timeout_seconds"; + } /// diff --git a/csharp/src/FeatureFlagCache.cs b/csharp/src/FeatureFlagCache.cs new file mode 100644 index 00000000..8a0598e6 --- /dev/null +++ b/csharp/src/FeatureFlagCache.cs @@ -0,0 +1,492 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Drivers.Apache.Spark; +using Microsoft.Extensions.Caching.Memory; + +namespace AdbcDrivers.Databricks +{ + /// + /// Singleton that manages feature flag cache per host using IMemoryCache. + /// Prevents rate limiting by caching feature flag responses with TTL-based expiration. + /// + /// + /// + /// This class implements a per-host caching pattern: + /// - Feature flags are cached by host to prevent rate limiting + /// - TTL-based expiration using server's ttl_seconds + /// - Background refresh task runs based on TTL within each FeatureFlagContext + /// - Automatic cleanup via IMemoryCache eviction + /// - Thread-safe using IMemoryCache and SemaphoreSlim for async locking + /// + /// + /// JDBC Reference: DatabricksDriverFeatureFlagsContextFactory.java + /// + /// + internal sealed class FeatureFlagCache : IDisposable + { + private static readonly FeatureFlagCache s_instance = new FeatureFlagCache(); + + /// + /// Activity source for feature flag tracing. + /// + private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.FeatureFlagCache"); + + private readonly IMemoryCache _cache; + private readonly SemaphoreSlim _createLock = new SemaphoreSlim(1, 1); + private bool _disposed; + + /// + /// Gets the singleton instance of the FeatureFlagCache. + /// + public static FeatureFlagCache GetInstance() => s_instance; + + /// + /// Creates a new FeatureFlagCache with default MemoryCache. + /// + internal FeatureFlagCache() : this(new MemoryCache(new MemoryCacheOptions())) + { + } + + /// + /// Creates a new FeatureFlagCache with custom IMemoryCache (for testing). + /// + /// The memory cache to use. + internal FeatureFlagCache(IMemoryCache cache) + { + _cache = cache ?? 
throw new ArgumentNullException(nameof(cache)); + } + + /// + /// Gets or creates a feature flag context for the host asynchronously. + /// Waits for initial fetch to complete before returning. + /// + /// The host (Databricks workspace URL) to get or create a context for. + /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + /// The driver version for the API endpoint. + /// Optional custom endpoint format. If null, uses the default endpoint. + /// Cancellation token. + /// The feature flag context for the host. + /// Thrown when host is null or whitespace. + /// Thrown when httpClient is null. + public async Task GetOrCreateContextAsync( + string host, + HttpClient httpClient, + string driverVersion, + string? endpointFormat = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + var cacheKey = GetCacheKey(host); + + // Try to get existing context (fast path) + if (_cache.TryGetValue(cacheKey, out FeatureFlagContext? 
context) && context != null) + { + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.cache_hit", + tags: new ActivityTagsCollection { { "host", host } })); + + return context; + } + + // Cache miss - create new context with async lock to prevent duplicate creation + await _createLock.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + // Double-check after acquiring lock + if (_cache.TryGetValue(cacheKey, out context) && context != null) + { + return context; + } + + // Create context asynchronously - this waits for initial fetch to complete + context = await FeatureFlagContext.CreateAsync( + host, + httpClient, + driverVersion, + endpointFormat, + cancellationToken).ConfigureAwait(false); + + // Set cache options with TTL from context + var cacheOptions = new MemoryCacheEntryOptions() + .SetAbsoluteExpiration(context.Ttl) + .RegisterPostEvictionCallback(OnCacheEntryEvicted); + + _cache.Set(cacheKey, context, cacheOptions); + + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.context_created", + tags: new ActivityTagsCollection + { + { "host", host }, + { "ttl_seconds", context.Ttl.TotalSeconds } + })); + + return context; + } + finally + { + _createLock.Release(); + } + } + + /// + /// Synchronous wrapper for GetOrCreateContextAsync. + /// Used for backward compatibility with synchronous callers. + /// + public FeatureFlagContext GetOrCreateContext( + string host, + HttpClient httpClient, + string driverVersion, + string? endpointFormat = null) + { + return GetOrCreateContextAsync(host, httpClient, driverVersion, endpointFormat) + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + + /// + /// Callback invoked when a cache entry is evicted. + /// Disposes the context to clean up resources (stops background refresh task). + /// + private static void OnCacheEntryEvicted(object key, object? value, EvictionReason reason, object? 
state) + { + if (value is FeatureFlagContext context) + { + context.Dispose(); + + Activity.Current?.AddEvent(new ActivityEvent("feature_flags.context_evicted", + tags: new ActivityTagsCollection + { + { "host", context.Host }, + { "reason", reason.ToString() } + })); + } + } + + /// + /// Gets the cache key for a host. + /// + private static string GetCacheKey(string host) + { + return $"feature_flags:{host.ToLowerInvariant()}"; + } + + /// + /// Gets the number of hosts currently cached. + /// Note: IMemoryCache doesn't expose count directly, so this returns -1 if not available. + /// + internal int CachedHostCount + { + get + { + if (_cache is MemoryCache memoryCache) + { + return memoryCache.Count; + } + return -1; + } + } + + /// + /// Checks if a context exists for the specified host. + /// + /// The host to check. + /// True if a context exists, false otherwise. + internal bool HasContext(string host) + { + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _cache.TryGetValue(GetCacheKey(host), out _); + } + + /// + /// Gets the context for the specified host, if it exists. + /// Does not create a new context. + /// + /// The host to get the context for. + /// The context if found, null otherwise. + /// True if the context was found, false otherwise. + internal bool TryGetContext(string host, out FeatureFlagContext? context) + { + context = null; + + if (string.IsNullOrWhiteSpace(host)) + { + return false; + } + + return _cache.TryGetValue(GetCacheKey(host), out context); + } + + /// + /// Removes a context from the cache. + /// + /// The host to remove. + internal void RemoveContext(string host) + { + if (!string.IsNullOrWhiteSpace(host)) + { + _cache.Remove(GetCacheKey(host)); + } + } + + /// + /// Clears all cached contexts. + /// This is primarily for testing purposes. 
+ /// + internal void Clear() + { + if (_cache is MemoryCache memoryCache) + { + memoryCache.Compact(1.0); // Remove all entries + } + } + + /// + /// Disposes the cache and all cached contexts. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + _createLock.Dispose(); + + if (_cache is IDisposable disposableCache) + { + disposableCache.Dispose(); + } + } + + /// + /// Merges feature flags from server into properties asynchronously. + /// Feature flags (remote properties) have lower priority than user-specified properties (local properties). + /// Priority: Local Properties > Remote Properties (Feature Flags) > Driver Defaults + /// + /// Local properties from user configuration and environment. + /// The driver version for the API endpoint. + /// Cancellation token. + /// Properties with remote feature flags merged in (local properties take precedence). + public async Task> MergePropertiesWithFeatureFlagsAsync( + IReadOnlyDictionary localProperties, + string assemblyVersion, + CancellationToken cancellationToken = default) + { + using var activity = s_activitySource.StartActivity("MergePropertiesWithFeatureFlags"); + + try + { + // Extract host from local properties + var host = TryGetHost(localProperties); + if (string.IsNullOrEmpty(host)) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_host" } })); + return localProperties; + } + + activity?.SetTag("feature_flags.host", host); + + // Create HttpClient for feature flag API + using var httpClient = CreateFeatureFlagHttpClient(host, assemblyVersion, localProperties); + + if (httpClient == null) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_auth_credentials" } })); + return localProperties; + } + + // Get or create feature flag context asynchronously (waits for initial fetch) + var context = await GetOrCreateContextAsync(host, 
httpClient, assemblyVersion, cancellationToken: cancellationToken).ConfigureAwait(false); + + // Get all flags from cache (remote properties) + var remoteProperties = context.GetAllFlags(); + + if (remoteProperties.Count == 0) + { + activity?.AddEvent(new ActivityEvent("feature_flags.skipped", + tags: new ActivityTagsCollection { { "reason", "no_flags_returned" } })); + return localProperties; + } + + activity?.SetTag("feature_flags.count", remoteProperties.Count); + activity?.AddEvent(new ActivityEvent("feature_flags.merging", + tags: new ActivityTagsCollection { { "flags_count", remoteProperties.Count } })); + + // Merge: remote properties (feature flags) as base, local properties override + // This ensures local properties take precedence over remote flags + return MergeProperties(remoteProperties, localProperties); + } + catch (Exception ex) + { + // Feature flag failures should never break the connection + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + activity?.AddEvent(new ActivityEvent("feature_flags.error", + tags: new ActivityTagsCollection + { + { "error.type", ex.GetType().Name }, + { "error.message", ex.Message } + })); + return localProperties; + } + } + + /// + /// Synchronous wrapper for MergePropertiesWithFeatureFlagsAsync. + /// Used for backward compatibility with synchronous callers. + /// + public IReadOnlyDictionary MergePropertiesWithFeatureFlags( + IReadOnlyDictionary localProperties, + string assemblyVersion) + { + return MergePropertiesWithFeatureFlagsAsync(localProperties, assemblyVersion) + .ConfigureAwait(false) + .GetAwaiter() + .GetResult(); + } + + /// + /// Tries to extract the host from properties without throwing. + /// Handles cases where user puts protocol in host (e.g., "https://myhost.databricks.com"). + /// + /// Connection properties. + /// The host (without protocol), or null if not found. + internal static string? 
TryGetHost(IReadOnlyDictionary properties) + { + if (properties.TryGetValue(SparkParameters.HostName, out string? host) && !string.IsNullOrEmpty(host)) + { + // Handle case where user puts protocol in host + return StripProtocol(host); + } + + if (properties.TryGetValue(Apache.Arrow.Adbc.AdbcOptions.Uri, out string? uri) && !string.IsNullOrEmpty(uri)) + { + if (Uri.TryCreate(uri, UriKind.Absolute, out Uri? parsedUri)) + { + return parsedUri.Host; + } + } + + return null; + } + + /// + /// Strips protocol prefix from a host string if present. + /// + /// The host string that may contain a protocol. + /// The host without protocol prefix. + private static string StripProtocol(string host) + { + // Try to parse as URI first to handle full URLs + if (Uri.TryCreate(host, UriKind.Absolute, out Uri? parsedUri) && + (parsedUri.Scheme == Uri.UriSchemeHttp || parsedUri.Scheme == Uri.UriSchemeHttps)) + { + return parsedUri.Host; + } + + // Fallback: strip common protocol prefixes manually + if (host.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + return host.Substring(8); + } + if (host.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + return host.Substring(7); + } + + return host; + } + + /// + /// Creates an HttpClient configured for the feature flag API. + /// Supports all authentication methods including: + /// - PAT (Personal Access Token) + /// - OAuth M2M (client_credentials) + /// - OAuth U2M (access_token) + /// - Workload Identity Federation (via MandatoryTokenExchangeDelegatingHandler) + /// Respects proxy settings and TLS options from connection properties. + /// + /// The Databricks host (without protocol). + /// The driver version for the User-Agent. + /// Connection properties for proxy, TLS, and auth configuration. + /// Configured HttpClient, or null if no valid authentication is available. + private static HttpClient? 
CreateFeatureFlagHttpClient( + string host, + string assemblyVersion, + IReadOnlyDictionary properties) + { + // Use centralized factory to create HttpClient with full auth handler chain + // This properly supports Workload Identity Federation via MandatoryTokenExchangeDelegatingHandler + return Http.HttpClientFactory.CreateFeatureFlagHttpClient(properties, host, assemblyVersion); + } + + /// + /// Merges two property dictionaries. Additional properties override base properties. + /// + /// Base properties (lower priority). + /// Additional properties (higher priority). + /// Merged properties. + private static IReadOnlyDictionary MergeProperties( + IReadOnlyDictionary baseProperties, + IReadOnlyDictionary additionalProperties) + { + var merged = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Add base properties first + foreach (var kvp in baseProperties) + { + merged[kvp.Key] = kvp.Value; + } + + // Additional properties override base properties + foreach (var kvp in additionalProperties) + { + merged[kvp.Key] = kvp.Value; + } + + return merged; + } + } +} diff --git a/csharp/src/FeatureFlagContext.cs b/csharp/src/FeatureFlagContext.cs new file mode 100644 index 00000000..1cf44f7b --- /dev/null +++ b/csharp/src/FeatureFlagContext.cs @@ -0,0 +1,370 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Drivers.Apache; +using Apache.Arrow.Adbc.Tracing; + +namespace AdbcDrivers.Databricks +{ + /// + /// Holds feature flag state for a host. + /// Cached by FeatureFlagCache with TTL-based expiration. + /// + /// + /// Each host (Databricks workspace) has one FeatureFlagContext instance + /// that is shared across all connections to that host. The context: + /// - Caches all feature flags returned by the server + /// - Tracks TTL from server response for cache expiration + /// - Runs a background refresh task based on TTL + /// + /// Thread-safety is ensured using ConcurrentDictionary for flag storage. + /// + /// JDBC Reference: DatabricksDriverFeatureFlagsContext.java + /// + internal sealed class FeatureFlagContext : IDisposable + { + /// + /// Activity source for feature flag tracing. + /// + private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.FeatureFlags"); + + /// + /// Assembly version for the driver. + /// + private static readonly string s_assemblyVersion = ApacheUtility.GetAssemblyVersion(typeof(FeatureFlagContext)); + + /// + /// Default TTL (15 minutes) if server doesn't specify ttl_seconds. + /// + public static readonly TimeSpan DefaultTtl = TimeSpan.FromMinutes(15); + + /// + /// Default feature flag endpoint format. {0} = driver version. + /// NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side. + /// + internal const string DefaultFeatureFlagEndpointFormat = "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; + + private readonly string _host; + private readonly string _driverVersion; + private readonly string _endpointFormat; + private readonly HttpClient? 
_httpClient; + private readonly ConcurrentDictionary _flags; + private readonly CancellationTokenSource _refreshCts; + private readonly object _ttlLock = new object(); + + private Task? _refreshTask; + private TimeSpan _ttl; + private bool _disposed; + + /// + /// Gets the current TTL (from server ttl_seconds). + /// + public TimeSpan Ttl + { + get + { + lock (_ttlLock) + { + return _ttl; + } + } + internal set + { + lock (_ttlLock) + { + _ttl = value; + } + } + } + + /// + /// Gets the host this context is for. + /// + public string Host => _host; + + /// + /// Gets the current refresh interval (alias for Ttl). + /// + public TimeSpan RefreshInterval => Ttl; + + /// + /// Internal constructor - use CreateAsync factory method for production code. + /// Made internal to allow test code to create instances without HTTP calls. + /// + internal FeatureFlagContext(string host, HttpClient? httpClient, string driverVersion, string? endpointFormat) + { + _host = host; + _httpClient = httpClient; + _driverVersion = driverVersion ?? "1.0.0"; + _endpointFormat = endpointFormat ?? DefaultFeatureFlagEndpointFormat; + _flags = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + _ttl = DefaultTtl; + _refreshCts = new CancellationTokenSource(); + } + + /// + /// Creates a new context with the given HTTP client. + /// Performs initial async fetch to populate cache, then starts background refresh task. + /// + /// The Databricks host. + /// + /// HttpClient from the connection, pre-configured with: + /// - Base address (https://{host}) + /// - Auth headers (Bearer token) + /// - Custom User-Agent for connector service + /// + /// The driver version for the API endpoint. + /// Optional custom endpoint format. If null, uses the default endpoint. + /// Cancellation token for the initial fetch. + /// A fully initialized FeatureFlagContext. + public static async Task CreateAsync( + string host, + HttpClient httpClient, + string driverVersion, + string? 
endpointFormat = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new ArgumentException("Host cannot be null or whitespace.", nameof(host)); + } + + if (httpClient == null) + { + throw new ArgumentNullException(nameof(httpClient)); + } + + var context = new FeatureFlagContext(host, httpClient, driverVersion, endpointFormat); + + // Initial async fetch - wait for it to complete + await context.FetchFeatureFlagsAsync("Initial", cancellationToken).ConfigureAwait(false); + + // Start background refresh task + context.StartBackgroundRefresh(); + + return context; + } + + /// + /// Gets a feature flag value by name. + /// Returns null if the flag is not found. + /// + /// The feature flag name. + /// The flag value, or null if not found. + public string? GetFlagValue(string flagName) + { + if (string.IsNullOrWhiteSpace(flagName)) + { + return null; + } + + return _flags.TryGetValue(flagName, out var value) ? value : null; + } + + /// + /// Gets all cached feature flags as a dictionary. + /// Can be used to merge with user properties. + /// + /// A read-only dictionary of all cached flags. + public IReadOnlyDictionary GetAllFlags() + { + // Return a snapshot to avoid concurrency issues + return new Dictionary(_flags, StringComparer.OrdinalIgnoreCase); + } + + /// + /// Starts the background refresh task that periodically fetches flags based on TTL. 
+ /// + private void StartBackgroundRefresh() + { + _refreshTask = Task.Run(async () => + { + while (!_refreshCts.Token.IsCancellationRequested) + { + try + { + // Wait for TTL duration before refreshing + await Task.Delay(Ttl, _refreshCts.Token).ConfigureAwait(false); + + if (!_refreshCts.Token.IsCancellationRequested) + { + await FetchFeatureFlagsAsync("Background", _refreshCts.Token).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + // Normal cancellation, exit the loop + break; + } + catch (Exception ex) + { + // Log error but continue the refresh loop + Activity.Current?.AddEvent("feature_flags.background_refresh.error", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + }, _refreshCts.Token); + } + + /// + /// Disposes the context and stops the background refresh task. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + + // Cancel the background refresh task + _refreshCts.Cancel(); + + // Wait briefly for the task to complete (don't block indefinitely) + try + { + _refreshTask?.Wait(TimeSpan.FromSeconds(1)); + } + catch (AggregateException) + { + // Task was cancelled, ignore + } + + _refreshCts.Dispose(); + } + + /// + /// Fetches feature flags from the API endpoint asynchronously. + /// + /// Type of fetch for logging purposes (e.g., "Initial" or "Background"). + /// Cancellation token. 
+ private async Task FetchFeatureFlagsAsync(string fetchType, CancellationToken cancellationToken) + { + if (_httpClient == null) + { + return; + } + + using var activity = s_activitySource.StartActivity($"FetchFeatureFlags.{fetchType}"); + activity?.SetTag("feature_flags.host", _host); + activity?.SetTag("feature_flags.fetch_type", fetchType); + + try + { + var endpoint = string.Format(_endpointFormat, _driverVersion); + activity?.SetTag("feature_flags.endpoint", endpoint); + + var response = await _httpClient.GetAsync(endpoint, cancellationToken).ConfigureAwait(false); + + activity?.SetTag("feature_flags.response.status_code", (int)response.StatusCode); + + // Use the standard EnsureSuccessOrThrow extension method + response.EnsureSuccessOrThrow(); + + var content = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + ProcessResponse(content, activity); + + activity?.SetStatus(ActivityStatusCode.Ok); + } + catch (OperationCanceledException) + { + // Propagate cancellation + throw; + } + catch (Exception ex) + { + // Swallow exceptions - telemetry should not break the connection + activity?.SetStatus(ActivityStatusCode.Error, ex.Message); + activity?.AddEvent("feature_flags.fetch.failed", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + + /// + /// Processes the JSON response and updates the cache. + /// + /// The JSON response content. + /// The current activity for tracing. + private void ProcessResponse(string content, Activity? activity) + { + try + { + var response = JsonSerializer.Deserialize(content); + + if (response?.Flags != null) + { + foreach (var flag in response.Flags) + { + if (!string.IsNullOrEmpty(flag.Name)) + { + _flags[flag.Name] = flag.Value ?? 
string.Empty; + } + } + + activity?.SetTag("feature_flags.count", response.Flags.Count); + activity?.AddEvent("feature_flags.updated", [ + new("flags_count", response.Flags.Count) + ]); + } + + // Update TTL if server provides a different value + if (response?.TtlSeconds != null && response.TtlSeconds > 0) + { + Ttl = TimeSpan.FromSeconds(response.TtlSeconds.Value); + activity?.SetTag("feature_flags.ttl_seconds", response.TtlSeconds.Value); + } + } + catch (JsonException ex) + { + activity?.AddEvent("feature_flags.parse.failed", [ + new("error.message", ex.Message), + new("error.type", ex.GetType().Name) + ]); + } + } + + /// + /// Clears all cached flags. + /// This is primarily for testing purposes. + /// + internal void ClearFlags() + { + _flags.Clear(); + } + + /// + /// Sets a flag value directly. + /// This is primarily for testing purposes. + /// + internal void SetFlag(string name, string value) + { + _flags[name] = value; + } + } +} diff --git a/csharp/src/FeatureFlagsResponse.cs b/csharp/src/FeatureFlagsResponse.cs new file mode 100644 index 00000000..088fcf3d --- /dev/null +++ b/csharp/src/FeatureFlagsResponse.cs @@ -0,0 +1,59 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace AdbcDrivers.Databricks +{ + /// + /// Response model for the feature flags API. 
+ /// Maps to the JSON response from /api/2.0/connector-service/feature-flags/{driver_type}/{version}. + /// + internal sealed class FeatureFlagsResponse + { + /// + /// Array of feature flag entries with name and value. + /// + [JsonPropertyName("flags")] + public List? Flags { get; set; } + + /// + /// Server-controlled refresh interval in seconds. + /// Default is 900 (15 minutes) if not provided. + /// + [JsonPropertyName("ttl_seconds")] + public int? TtlSeconds { get; set; } + } + + /// + /// Individual feature flag entry with name and value. + /// + internal sealed class FeatureFlagEntry + { + /// + /// The feature flag name (e.g., "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"). + /// + [JsonPropertyName("name")] + public string Name { get; set; } = string.Empty; + + /// + /// The feature flag value as a string (e.g., "true", "false", "10"). + /// + [JsonPropertyName("value")] + public string Value { get; set; } = string.Empty; + } +} diff --git a/csharp/src/Http/HttpClientFactory.cs b/csharp/src/Http/HttpClientFactory.cs new file mode 100644 index 00000000..b9282ae3 --- /dev/null +++ b/csharp/src/Http/HttpClientFactory.cs @@ -0,0 +1,140 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
*/

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Headers;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;

namespace AdbcDrivers.Databricks.Http
{
    /// <summary>
    /// Centralized factory for creating HttpClient instances with proper TLS and proxy configuration.
    /// All HttpClient creation should go through this factory to ensure consistent configuration.
    /// </summary>
    internal static class HttpClientFactory
    {
        /// <summary>
        /// Creates an HttpClientHandler with TLS and proxy settings from connection properties.
        /// This is the base handler used by all HttpClient instances.
        /// </summary>
        /// <param name="properties">Connection properties containing TLS and proxy configuration.</param>
        /// <returns>Configured HttpClientHandler.</returns>
        public static HttpClientHandler CreateHandler(IReadOnlyDictionary<string, string> properties)
        {
            var tlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(properties);
            var proxyConfigurator = HiveServer2ProxyConfigurator.FromProperties(properties);
            return HiveServer2TlsImpl.NewHttpClientHandler(tlsOptions, proxyConfigurator);
        }

        /// <summary>
        /// Creates a basic HttpClient with TLS and proxy settings.
        /// Use this for simple HTTP operations that don't need auth handlers or retries.
        /// </summary>
        /// <param name="properties">Connection properties containing TLS and proxy configuration.</param>
        /// <param name="timeout">Optional timeout. If not specified, uses default HttpClient timeout.</param>
        /// <returns>Configured HttpClient.</returns>
        public static HttpClient CreateBasicHttpClient(IReadOnlyDictionary<string, string> properties, TimeSpan? timeout = null)
        {
            var handler = CreateHandler(properties);
            var httpClient = new HttpClient(handler);

            if (timeout.HasValue)
            {
                httpClient.Timeout = timeout.Value;
            }

            return httpClient;
        }

        /// <summary>
        /// Creates an HttpClient for CloudFetch downloads.
        /// Includes TLS and proxy settings but no auth headers (CloudFetch uses pre-signed URLs).
        /// </summary>
        /// <param name="properties">Connection properties containing TLS and proxy configuration.</param>
        /// <returns>Configured HttpClient for CloudFetch.</returns>
        public static HttpClient CreateCloudFetchHttpClient(IReadOnlyDictionary<string, string> properties)
        {
            int timeoutMinutes = PropertyHelper.GetPositiveIntPropertyWithValidation(
                properties,
                DatabricksParameters.CloudFetchTimeoutMinutes,
                DatabricksConstants.DefaultCloudFetchTimeoutMinutes);

            return CreateBasicHttpClient(properties, TimeSpan.FromMinutes(timeoutMinutes));
        }

        /// <summary>
        /// Creates an HttpClient for feature flag API calls.
        /// Includes TLS, proxy settings, and full authentication handler chain.
        /// Supports PAT, OAuth M2M, OAuth U2M, and Workload Identity Federation.
        /// </summary>
        /// <param name="properties">Connection properties.</param>
        /// <param name="host">The Databricks host (without protocol).</param>
        /// <param name="assemblyVersion">The driver version for the User-Agent.</param>
        /// <returns>Configured HttpClient for feature flags, or null if no valid authentication is available.</returns>
        public static HttpClient? CreateFeatureFlagHttpClient(
            IReadOnlyDictionary<string, string> properties,
            string host,
            string assemblyVersion)
        {
            const int DefaultFeatureFlagTimeoutSeconds = 10;

            var timeoutSeconds = PropertyHelper.GetPositiveIntPropertyWithValidation(
                properties,
                DatabricksParameters.FeatureFlagTimeoutSeconds,
                DefaultFeatureFlagTimeoutSeconds);

            // Create handler with full auth chain (including WIF support)
            var handler = HttpHandlerFactory.CreateFeatureFlagHandler(properties, host, timeoutSeconds);
            if (handler == null)
            {
                return null;
            }

            var httpClient = new HttpClient(handler)
            {
                BaseAddress = new Uri($"https://{host}"),
                Timeout = TimeSpan.FromSeconds(timeoutSeconds)
            };

            // Use same User-Agent format as other Databricks HTTP clients
            string userAgent = $"DatabricksJDBCDriverOSS/{assemblyVersion} (ADBC)";
            string userAgentEntry = PropertyHelper.GetStringProperty(properties, "adbc.spark.user_agent_entry", string.Empty);
            if (!string.IsNullOrWhiteSpace(userAgentEntry))
            {
                userAgent = $"{userAgent} {userAgentEntry}";
            }

            httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(userAgent);

            return httpClient;
        }

        /// <summary>
/// Creates an HttpClient for OAuth token operations. + /// Includes TLS and proxy settings with configurable timeout. + /// + /// Connection properties. + /// Timeout for OAuth operations. + /// Configured HttpClient for OAuth. + public static HttpClient CreateOAuthHttpClient(IReadOnlyDictionary properties, TimeSpan timeout) + { + return CreateBasicHttpClient(properties, timeout); + } + } +} diff --git a/csharp/src/Http/HttpHandlerFactory.cs b/csharp/src/Http/HttpHandlerFactory.cs index 31778e5d..0f27363c 100644 --- a/csharp/src/Http/HttpHandlerFactory.cs +++ b/csharp/src/Http/HttpHandlerFactory.cs @@ -126,6 +126,151 @@ internal class HandlerResult public HttpClient? AuthHttpClient { get; set; } } + /// + /// Checks if OAuth authentication is configured in properties. + /// + private static bool IsOAuthEnabled(IReadOnlyDictionary properties) + { + return properties.TryGetValue(SparkParameters.AuthType, out string? authType) && + SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && + authTypeValue == SparkAuthType.OAuth; + } + + /// + /// Gets the OAuth grant type from properties. + /// + private static DatabricksOAuthGrantType GetOAuthGrantType(IReadOnlyDictionary properties) + { + properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); + DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + return grantType; + } + + /// + /// Creates an OAuthClientCredentialsProvider for M2M authentication. + /// Returns null if client ID or secret is missing. + /// + private static OAuthClientCredentialsProvider? CreateOAuthClientCredentialsProvider( + IReadOnlyDictionary properties, + HttpClient authHttpClient, + string host) + { + properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); + properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); + properties.TryGetValue(DatabricksParameters.OAuthScope, out string? 
scope); + + if (string.IsNullOrEmpty(clientId) || string.IsNullOrEmpty(clientSecret)) + { + return null; + } + + return new OAuthClientCredentialsProvider( + authHttpClient, + clientId, + clientSecret, + host, + scope: scope ?? "sql", + timeoutMinutes: 1); + } + + /// + /// Adds authentication handlers to the handler chain. + /// Returns the new handler chain with auth handlers added, or null if auth is required but not available. + /// + /// The current handler chain. + /// Connection properties. + /// The Databricks host. + /// HTTP client for auth operations (required for OAuth). + /// Identity federation client ID (optional). + /// Whether to enable JWT token refresh for access_token grant type. + /// Output: the auth HTTP client if created. + /// Handler with auth handlers added, or null if required auth is not available. + private static HttpMessageHandler? AddAuthHandlers( + HttpMessageHandler handler, + IReadOnlyDictionary properties, + string host, + HttpClient? authHttpClient, + string? identityFederationClientId, + bool enableTokenRefresh, + out HttpClient? authHttpClientOut) + { + authHttpClientOut = authHttpClient; + + if (IsOAuthEnabled(properties)) + { + if (authHttpClient == null) + { + return null; // OAuth requires auth HTTP client + } + + ITokenExchangeClient tokenExchangeClient = new TokenExchangeClient(authHttpClient, host); + + // Mandatory token exchange should be the inner handler so that it happens + // AFTER the OAuth handlers (e.g. 
after M2M sets the access token) + handler = new MandatoryTokenExchangeDelegatingHandler( + handler, + tokenExchangeClient, + identityFederationClientId); + + var grantType = GetOAuthGrantType(properties); + + if (grantType == DatabricksOAuthGrantType.ClientCredentials) + { + var tokenProvider = CreateOAuthClientCredentialsProvider(properties, authHttpClient, host); + if (tokenProvider == null) + { + return null; // Missing client credentials + } + handler = new OAuthDelegatingHandler(handler, tokenProvider); + } + else if (grantType == DatabricksOAuthGrantType.AccessToken) + { + string? accessToken = AuthHelper.GetTokenFromProperties(properties); + if (string.IsNullOrEmpty(accessToken)) + { + return null; // No access token + } + + if (enableTokenRefresh && + properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string? tokenRenewLimitStr) && + int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit) && + tokenRenewLimit > 0 && + JwtTokenDecoder.TryGetExpirationTime(accessToken, out DateTime expiryTime)) + { + handler = new TokenRefreshDelegatingHandler( + handler, + tokenExchangeClient, + accessToken, + expiryTime, + tokenRenewLimit); + } + else + { + handler = new StaticBearerTokenHandler(handler, accessToken); + } + } + else + { + return null; // Unknown grant type + } + } + else + { + // Non-OAuth authentication: use static Bearer token if provided + string? accessToken = AuthHelper.GetTokenFromProperties(properties); + if (!string.IsNullOrEmpty(accessToken)) + { + handler = new StaticBearerTokenHandler(handler, accessToken); + } + else + { + return null; // No auth available + } + } + + return handler; + } + /// /// Creates HTTP handlers with OAuth and other delegating handlers. /// @@ -175,115 +320,83 @@ public static HandlerResult CreateHandlers(HandlerConfig config) authHandler = new ThriftErrorMessageHandler(authHandler); } + // Create auth HTTP client for OAuth if needed HttpClient? 
authHttpClient = null; - - // Check if OAuth authentication is configured - bool useOAuth = config.Properties.TryGetValue(SparkParameters.AuthType, out string? authType) && - SparkAuthTypeParser.TryParse(authType, out SparkAuthType authTypeValue) && - authTypeValue == SparkAuthType.OAuth; - - if (useOAuth) + if (IsOAuthEnabled(config.Properties)) { - // Create auth HTTP client for token operations authHttpClient = new HttpClient(authHandler) { Timeout = TimeSpan.FromMinutes(config.TimeoutMinutes) }; + } - ITokenExchangeClient tokenExchangeClient = new TokenExchangeClient(authHttpClient, config.Host); - - // Mandatory token exchange should be the inner handler so that it happens - // AFTER the OAuth handlers (e.g. after M2M sets the access token) - handler = new MandatoryTokenExchangeDelegatingHandler( - handler, - tokenExchangeClient, - config.IdentityFederationClientId); - - // Determine grant type (defaults to AccessToken if not specified) - config.Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string? grantTypeStr); - DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out DatabricksOAuthGrantType grantType); + // Add auth handlers + var resultHandler = AddAuthHandlers( + handler, + config.Properties, + config.Host, + authHttpClient, + config.IdentityFederationClientId, + enableTokenRefresh: true, + out authHttpClient); - // Add OAuth client credentials handler if OAuth M2M authentication is being used - if (grantType == DatabricksOAuthGrantType.ClientCredentials) - { - config.Properties.TryGetValue(DatabricksParameters.OAuthClientId, out string? clientId); - config.Properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string? clientSecret); - config.Properties.TryGetValue(DatabricksParameters.OAuthScope, out string? scope); - - var tokenProvider = new OAuthClientCredentialsProvider( - authHttpClient, - clientId!, - clientSecret!, - config.Host, - scope: scope ?? 
"sql", - timeoutMinutes: 1 - ); + return new HandlerResult + { + Handler = resultHandler ?? handler, // Fall back to handler without auth if auth not configured + AuthHttpClient = authHttpClient + }; + } - handler = new OAuthDelegatingHandler(handler, tokenProvider); - } - // For access_token grant type, get the access token from properties - else if (grantType == DatabricksOAuthGrantType.AccessToken) - { - // Get the access token from properties - string accessToken = string.Empty; - if (config.Properties.TryGetValue(SparkParameters.AccessToken, out string? token)) - { - accessToken = token ?? string.Empty; - } - else if (config.Properties.TryGetValue(SparkParameters.Token, out string? fallbackToken)) - { - accessToken = fallbackToken ?? string.Empty; - } + /// + /// Creates an HTTP handler chain specifically for feature flag API calls. + /// This is a simplified version of CreateHandlers that includes auth but not + /// tracing, retry, or Thrift error handlers. + /// + /// Handler chain order (outermost to innermost): + /// 1. OAuth handlers (OAuthDelegatingHandler or StaticBearerTokenHandler) - token management + /// 2. MandatoryTokenExchangeDelegatingHandler (if OAuth) - workload identity federation + /// 3. Base HTTP handler - actual network communication + /// + /// This properly supports all authentication methods including: + /// - PAT (Personal Access Token) + /// - OAuth M2M (client_credentials) + /// - OAuth U2M (access_token) + /// - Workload Identity Federation (via MandatoryTokenExchangeDelegatingHandler) + /// + /// Connection properties containing configuration. + /// The Databricks host (without protocol). + /// HTTP client timeout in seconds. + /// Configured HttpMessageHandler, or null if no valid authentication is available. + public static HttpMessageHandler? 
CreateFeatureFlagHandler( + IReadOnlyDictionary properties, + string host, + int timeoutSeconds) + { + HttpMessageHandler baseHandler = HttpClientFactory.CreateHandler(properties); - if (!string.IsNullOrEmpty(accessToken)) - { - // Check if token renewal is configured and token is JWT - if (config.Properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string? tokenRenewLimitStr) && - int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit) && - tokenRenewLimit > 0 && - JwtTokenDecoder.TryGetExpirationTime(accessToken, out DateTime expiryTime)) - { - // Use TokenRefreshDelegatingHandler for JWT tokens with renewal configured - handler = new TokenRefreshDelegatingHandler( - handler, - tokenExchangeClient, - accessToken, - expiryTime, - tokenRenewLimit); - } - else - { - // Use StaticBearerTokenHandler for tokens without renewal - handler = new StaticBearerTokenHandler(handler, accessToken); - } - } - } - } - else + // Create auth HTTP client for OAuth if needed + HttpClient? authHttpClient = null; + if (IsOAuthEnabled(properties)) { - // Non-OAuth authentication: use static Bearer token if provided - // Try access_token first, then fall back to token - string accessToken = string.Empty; - if (config.Properties.TryGetValue(SparkParameters.AccessToken, out string? token)) - { - accessToken = token ?? string.Empty; - } - else if (config.Properties.TryGetValue(SparkParameters.Token, out string? fallbackToken)) - { - accessToken = fallbackToken ?? 
string.Empty; - } - - if (!string.IsNullOrEmpty(accessToken)) + HttpMessageHandler baseAuthHandler = HttpClientFactory.CreateHandler(properties); + authHttpClient = new HttpClient(baseAuthHandler) { - handler = new StaticBearerTokenHandler(handler, accessToken); - } + Timeout = TimeSpan.FromSeconds(timeoutSeconds) + }; } - return new HandlerResult - { - Handler = handler, - AuthHttpClient = authHttpClient - }; + // Get identity federation client ID + properties.TryGetValue(DatabricksParameters.IdentityFederationClientId, out string? identityFederationClientId); + + // Add auth handlers (no token refresh for feature flags) + return AddAuthHandlers( + baseHandler, + properties, + host, + authHttpClient, + identityFederationClientId, + enableTokenRefresh: false, + out _); } } } diff --git a/csharp/src/StatementExecution/StatementExecutionConnection.cs b/csharp/src/StatementExecution/StatementExecutionConnection.cs index d18a456b..bc2a5a53 100644 --- a/csharp/src/StatementExecution/StatementExecutionConnection.cs +++ b/csharp/src/StatementExecution/StatementExecutionConnection.cs @@ -221,11 +221,8 @@ private StatementExecutionConnection( // Create a separate HTTP client for CloudFetch downloads (without auth headers) // This is needed because CloudFetch uses pre-signed URLs from cloud storage (S3, Azure Blob, etc.) 
// and those services reject requests with multiple authentication methods - int timeoutMinutes = PropertyHelper.GetPositiveIntPropertyWithValidation(properties, DatabricksParameters.CloudFetchTimeoutMinutes, DatabricksConstants.DefaultCloudFetchTimeoutMinutes); - _cloudFetchHttpClient = new HttpClient() - { - Timeout = TimeSpan.FromMinutes(timeoutMinutes) - }; + // Note: We still need proxy and TLS configuration for corporate network access + _cloudFetchHttpClient = HttpClientFactory.CreateCloudFetchHttpClient(properties); // Create REST API client _client = new StatementExecutionClient(_httpClient, baseUrl); @@ -251,8 +248,8 @@ private HttpClient CreateHttpClient(IReadOnlyDictionary properti var config = new HttpHandlerFactory.HandlerConfig { - BaseHandler = new HttpClientHandler(), - BaseAuthHandler = new HttpClientHandler(), + BaseHandler = HttpClientFactory.CreateHandler(properties), + BaseAuthHandler = HttpClientFactory.CreateHandler(properties), Properties = properties, Host = GetHost(properties), ActivityTracer = this, diff --git a/csharp/src/Telemetry/DatabricksTelemetryExporter.cs b/csharp/src/Telemetry/DatabricksTelemetryExporter.cs new file mode 100644 index 00000000..e0c8d6dd --- /dev/null +++ b/csharp/src/Telemetry/DatabricksTelemetryExporter.cs @@ -0,0 +1,285 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
*/

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using AdbcDrivers.Databricks.Telemetry.Models;

namespace AdbcDrivers.Databricks.Telemetry
{
    /// <summary>
    /// Exports telemetry events to the Databricks telemetry service.
    /// </summary>
    /// <remarks>
    /// This exporter:
    /// - Creates TelemetryRequest wrapper with uploadTime and protoLogs
    /// - Uses /telemetry-ext for authenticated requests
    /// - Uses /telemetry-unauth for unauthenticated requests
    /// - Implements retry logic for transient failures
    /// - Never throws exceptions (all swallowed and traced at Verbose level)
    ///
    /// JDBC Reference: TelemetryPushClient.java
    /// </remarks>
    internal sealed class DatabricksTelemetryExporter : ITelemetryExporter
    {
        /// <summary>
        /// Authenticated telemetry endpoint path.
        /// </summary>
        internal const string AuthenticatedEndpoint = "/telemetry-ext";

        /// <summary>
        /// Unauthenticated telemetry endpoint path.
        /// </summary>
        internal const string UnauthenticatedEndpoint = "/telemetry-unauth";

        /// <summary>
        /// Activity source for telemetry exporter tracing.
        /// </summary>
        private static readonly ActivitySource s_activitySource = new ActivitySource("AdbcDrivers.Databricks.TelemetryExporter");

        private readonly HttpClient _httpClient;
        private readonly string _host;
        private readonly bool _isAuthenticated;
        private readonly TelemetryConfiguration _config;

        private static readonly JsonSerializerOptions s_jsonOptions = new JsonSerializerOptions
        {
            PropertyNamingPolicy = JsonNamingPolicy.CamelCase
        };

        /// <summary>
        /// Gets the host URL for the telemetry endpoint.
        /// </summary>
        internal string Host => _host;

        /// <summary>
        /// Gets whether this exporter uses authenticated endpoints.
        /// </summary>
        internal bool IsAuthenticated => _isAuthenticated;

        /// <summary>
        /// Creates a new DatabricksTelemetryExporter.
        /// </summary>
        /// <param name="httpClient">The HTTP client to use for sending requests.</param>
        /// <param name="host">The Databricks host URL.</param>
        /// <param name="isAuthenticated">Whether to use authenticated endpoints.</param>
        /// <param name="config">The telemetry configuration.</param>
        /// <exception cref="ArgumentNullException">Thrown when httpClient, host, or config is null.</exception>
        /// <exception cref="ArgumentException">Thrown when host is empty or whitespace.</exception>
        public DatabricksTelemetryExporter(
            HttpClient httpClient,
            string host,
            bool isAuthenticated,
            TelemetryConfiguration config)
        {
            _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));

            if (string.IsNullOrWhiteSpace(host))
            {
                throw new ArgumentException("Host cannot be null or whitespace.", nameof(host));
            }

            _host = host;
            _isAuthenticated = isAuthenticated;
            _config = config ?? throw new ArgumentNullException(nameof(config));
        }

        /// <summary>
        /// Export telemetry frontend logs to the Databricks telemetry service.
        /// </summary>
        /// <param name="logs">The list of telemetry frontend logs to export.</param>
        /// <param name="ct">Cancellation token.</param>
        /// <returns>
        /// True if the export succeeded (HTTP 2xx response), false if it failed.
        /// Returns true for empty/null logs since there's nothing to export.
        /// </returns>
        /// <remarks>
        /// This method never throws exceptions. All errors are caught and traced using ActivitySource.
        /// </remarks>
        public async Task<bool> ExportAsync(IReadOnlyList<TelemetryFrontendLog> logs, CancellationToken ct = default)
        {
            if (logs == null || logs.Count == 0)
            {
                return true;
            }

            try
            {
                var request = CreateTelemetryRequest(logs);
                var json = SerializeRequest(request);

                return await SendWithRetryAsync(json, ct).ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Don't swallow cancellation - let it propagate
                throw;
            }
            catch (Exception ex)
            {
                // Swallow all other exceptions per telemetry requirement
                // Trace at Verbose level to avoid customer anxiety
                Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.error",
                    tags: new ActivityTagsCollection
                    {
                        { "error.message", ex.Message },
                        { "error.type", ex.GetType().Name }
                    }));
                return false;
            }
        }

        /// <summary>
        /// Creates a TelemetryRequest from a list of frontend logs.
        /// </summary>
        internal TelemetryRequest CreateTelemetryRequest(IReadOnlyList<TelemetryFrontendLog> logs)
        {
            var protoLogs = new List<string>(logs.Count);

            foreach (var log in logs)
            {
                var serializedLog = JsonSerializer.Serialize(log, s_jsonOptions);
                protoLogs.Add(serializedLog);
            }

            return new TelemetryRequest
            {
                UploadTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(),
                ProtoLogs = protoLogs
            };
        }

        /// <summary>
        /// Serializes the telemetry request to JSON.
        /// </summary>
        internal string SerializeRequest(TelemetryRequest request)
        {
            return JsonSerializer.Serialize(request, s_jsonOptions);
        }

        /// <summary>
        /// Gets the telemetry endpoint URL based on authentication status.
        /// </summary>
        internal string GetEndpointUrl()
        {
            var endpoint = _isAuthenticated ? AuthenticatedEndpoint : UnauthenticatedEndpoint;
            var host = _host.TrimEnd('/');
            return $"{host}{endpoint}";
        }

        /// <summary>
        /// Sends the telemetry request with retry logic.
        /// </summary>
        /// <returns>True if the request succeeded, false otherwise.</returns>
        private async Task<bool> SendWithRetryAsync(string json, CancellationToken ct)
        {
            var endpointUrl = GetEndpointUrl();
            Exception? lastException = null;

            for (int attempt = 0; attempt <= _config.MaxRetries; attempt++)
            {
                try
                {
                    if (attempt > 0 && _config.RetryDelayMs > 0)
                    {
                        await Task.Delay(_config.RetryDelayMs, ct).ConfigureAwait(false);
                    }

                    await SendRequestAsync(endpointUrl, json, ct).ConfigureAwait(false);

                    Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.success",
                        tags: new ActivityTagsCollection
                        {
                            { "endpoint", endpointUrl },
                            { "attempt", attempt + 1 }
                        }));
                    return true;
                }
                catch (OperationCanceledException)
                {
                    // Don't retry on cancellation
                    throw;
                }
                catch (HttpRequestException ex)
                {
                    lastException = ex;

                    // Check if this is a terminal error that shouldn't be retried
                    if (ExceptionClassifier.IsTerminalException(ex))
                    {
                        Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.terminal_error",
                            tags: new ActivityTagsCollection
                            {
                                { "error.message", ex.Message },
                                { "error.type", ex.GetType().Name }
                            }));
                        return false;
                    }

                    Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.retry",
                        tags: new ActivityTagsCollection
                        {
                            { "attempt", attempt + 1 },
                            { "max_attempts", _config.MaxRetries + 1 },
                            { "error.message", ex.Message }
                        }));
                }
                catch (Exception ex)
                {
                    lastException = ex;
                    Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.retry",
                        tags: new ActivityTagsCollection
                        {
                            { "attempt", attempt + 1 },
                            { "max_attempts", _config.MaxRetries + 1 },
                            { "error.message", ex.Message },
                            { "error.type", ex.GetType().Name }
                        }));
                }
            }

            if (lastException != null)
            {
                Activity.Current?.AddEvent(new ActivityEvent("telemetry.export.exhausted",
                    tags: new ActivityTagsCollection
                    {
                        { "total_attempts", _config.MaxRetries + 1 },
                        { "error.message", lastException.Message },
                        { "error.type", lastException.GetType().Name }
                    }));
            }

            return false;
        }

        /// <summary>
        /// Sends the HTTP request to the telemetry endpoint.
+ /// + private async Task SendRequestAsync(string endpointUrl, string json, CancellationToken ct) + { + using var content = new StringContent(json, Encoding.UTF8, "application/json"); + using var response = await _httpClient.PostAsync(endpointUrl, content, ct).ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + } + } +} diff --git a/csharp/src/Telemetry/ITelemetryExporter.cs b/csharp/src/Telemetry/ITelemetryExporter.cs new file mode 100644 index 00000000..934a5bba --- /dev/null +++ b/csharp/src/Telemetry/ITelemetryExporter.cs @@ -0,0 +1,53 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Interface for exporting telemetry events to a backend service. + /// + /// + /// Implementations of this interface must be safe to call from any context. + /// All methods should be non-blocking and should never throw exceptions + /// (exceptions should be caught and logged at TRACE level internally). + /// This follows the telemetry design principle that telemetry operations + /// should never impact driver operations. + /// + public interface ITelemetryExporter + { + /// + /// Export telemetry frontend logs to the backend service. + /// + /// The list of telemetry frontend logs to export. + /// Cancellation token. 
+ /// + /// A task that resolves to true if the export succeeded (HTTP 2xx response), + /// or false if the export failed or was skipped. Returns true for empty/null logs + /// since there's nothing to export (no failure occurred). + /// + /// + /// This method must never throw exceptions. All errors should be caught + /// and logged at TRACE level internally. The method may return early + /// if the circuit breaker is open or if there are no logs to export. + /// + Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default); + } +} diff --git a/csharp/src/Telemetry/MetricsAggregator.cs b/csharp/src/Telemetry/MetricsAggregator.cs new file mode 100644 index 00000000..29d8e3df --- /dev/null +++ b/csharp/src/Telemetry/MetricsAggregator.cs @@ -0,0 +1,737 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry.Models; +using AdbcDrivers.Databricks.Telemetry.TagDefinitions; + +namespace AdbcDrivers.Databricks.Telemetry +{ + /// + /// Aggregates metrics from activities by statement_id and includes session_id. + /// Follows JDBC driver pattern: aggregation by statement, export with both IDs. 
+ /// + /// + /// This class: + /// - Aggregates Activity data by statement_id using ConcurrentDictionary + /// - Connection events emit immediately (no aggregation needed) + /// - Statement events aggregate until CompleteStatement is called + /// - Terminal exceptions flush immediately, retryable exceptions buffer until complete + /// - Uses TelemetryTagRegistry to filter tags + /// - Creates TelemetryFrontendLog wrapper with workspace_id + /// - All exceptions swallowed (logged at TRACE level) + /// + /// JDBC Reference: TelemetryCollector.java + /// + internal sealed class MetricsAggregator : IDisposable + { + private readonly ITelemetryExporter _exporter; + private readonly TelemetryConfiguration _config; + private readonly long _workspaceId; + private readonly string _userAgent; + private readonly ConcurrentDictionary _statementContexts; + private readonly ConcurrentQueue _pendingEvents; + private readonly object _flushLock = new object(); + private readonly Timer _flushTimer; + private bool _disposed; + + /// + /// Gets the number of pending events waiting to be flushed. + /// + internal int PendingEventCount => _pendingEvents.Count; + + /// + /// Gets the number of active statement contexts being tracked. + /// + internal int ActiveStatementCount => _statementContexts.Count; + + /// + /// Creates a new MetricsAggregator. + /// + /// The telemetry exporter to use for flushing events. + /// The telemetry configuration. + /// The Databricks workspace ID for all events. + /// The user agent string for client context. + /// Thrown when exporter or config is null. + public MetricsAggregator( + ITelemetryExporter exporter, + TelemetryConfiguration config, + long workspaceId, + string? userAgent = null) + { + _exporter = exporter ?? throw new ArgumentNullException(nameof(exporter)); + _config = config ?? throw new ArgumentNullException(nameof(config)); + _workspaceId = workspaceId; + _userAgent = userAgent ?? 
"AdbcDatabricksDriver"; + _statementContexts = new ConcurrentDictionary(); + _pendingEvents = new ConcurrentQueue(); + + // Start flush timer + _flushTimer = new Timer( + OnFlushTimerTick, + null, + _config.FlushIntervalMs, + _config.FlushIntervalMs); + } + + /// + /// Processes a completed activity and extracts telemetry metrics. + /// + /// The activity to process. + /// + /// This method determines the event type based on the activity operation name: + /// - Connection.* activities emit connection events immediately + /// - Statement.* activities are aggregated by statement_id + /// - Activities with error.type tag are treated as error events + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void ProcessActivity(Activity? activity) + { + if (activity == null) + { + return; + } + + try + { + var eventType = DetermineEventType(activity); + + switch (eventType) + { + case TelemetryEventType.ConnectionOpen: + ProcessConnectionActivity(activity); + break; + + case TelemetryEventType.StatementExecution: + ProcessStatementActivity(activity); + break; + + case TelemetryEventType.Error: + ProcessErrorActivity(activity); + break; + } + + // Check if we should flush based on batch size + if (_pendingEvents.Count >= _config.BatchSize) + { + _ = FlushAsync(); + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + // Log at TRACE level to avoid customer anxiety + Debug.WriteLine($"[TRACE] MetricsAggregator: Error processing activity: {ex.Message}"); + } + } + + /// + /// Marks a statement as complete and emits the aggregated telemetry event. + /// + /// The statement ID to complete. + /// Whether the statement execution failed. 
+ /// + /// This method: + /// - Removes the statement context from the aggregation dictionary + /// - Emits the aggregated event with all collected metrics + /// - If failed, also emits any buffered retryable exception events + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void CompleteStatement(string statementId, bool failed = false) + { + if (string.IsNullOrEmpty(statementId)) + { + return; + } + + try + { + if (_statementContexts.TryRemove(statementId, out var context)) + { + // Emit the aggregated statement event + var telemetryEvent = CreateTelemetryEvent(context); + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + // If statement failed and we have buffered exceptions, emit them + if (failed && context.BufferedExceptions.Count > 0) + { + foreach (var exception in context.BufferedExceptions) + { + var errorEvent = CreateErrorTelemetryEvent( + context.SessionId, + statementId, + exception); + var errorLog = WrapInFrontendLog(errorEvent); + _pendingEvents.Enqueue(errorLog); + } + } + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] MetricsAggregator: Error completing statement: {ex.Message}"); + } + } + + /// + /// Records an exception for a statement. + /// + /// The statement ID. + /// The session ID. + /// The exception to record. + /// + /// Terminal exceptions are flushed immediately. + /// Retryable exceptions are buffered until CompleteStatement is called. + /// + /// All exceptions are swallowed and logged at TRACE level. + /// + public void RecordException(string statementId, string sessionId, Exception? 
exception) + { + if (exception == null || string.IsNullOrEmpty(statementId)) + { + return; + } + + try + { + if (ExceptionClassifier.IsTerminalException(exception)) + { + // Terminal exception: flush immediately + var errorEvent = CreateErrorTelemetryEvent(sessionId, statementId, exception); + var errorLog = WrapInFrontendLog(errorEvent); + _pendingEvents.Enqueue(errorLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Terminal exception recorded, flushing immediately"); + } + else + { + // Retryable exception: buffer until statement completes + var context = _statementContexts.GetOrAdd( + statementId, + _ => new StatementTelemetryContext(statementId, sessionId)); + context.BufferedExceptions.Add(exception); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Retryable exception buffered for statement {statementId}"); + } + } + catch (Exception ex) + { + // Swallow all exceptions per telemetry requirement + Debug.WriteLine($"[TRACE] MetricsAggregator: Error recording exception: {ex.Message}"); + } + } + + /// + /// Flushes all pending telemetry events to the exporter. + /// + /// Cancellation token. + /// A task representing the flush operation. + /// + /// This method is thread-safe and uses a lock to prevent concurrent flushes. + /// All exceptions are swallowed and logged at TRACE level. 
+        /// </remarks>
+        public async Task FlushAsync(CancellationToken ct = default)
+        {
+            try
+            {
+                List<TelemetryFrontendLog> eventsToFlush;
+
+                lock (_flushLock)
+                {
+                    if (_pendingEvents.IsEmpty)
+                    {
+                        return;
+                    }
+
+                    eventsToFlush = new List<TelemetryFrontendLog>();
+                    while (_pendingEvents.TryDequeue(out var eventLog))
+                    {
+                        eventsToFlush.Add(eventLog);
+                    }
+                }
+
+                if (eventsToFlush.Count > 0)
+                {
+                    Debug.WriteLine($"[TRACE] MetricsAggregator: Flushing {eventsToFlush.Count} events");
+                    await _exporter.ExportAsync(eventsToFlush, ct).ConfigureAwait(false);
+                }
+            }
+            catch (OperationCanceledException)
+            {
+                // Don't swallow cancellation
+                throw;
+            }
+            catch (Exception ex)
+            {
+                // Swallow all other exceptions per telemetry requirement
+                Debug.WriteLine($"[TRACE] MetricsAggregator: Error flushing events: {ex.Message}");
+            }
+        }
+
+        /// <summary>
+        /// Disposes the MetricsAggregator and flushes any remaining events.
+        /// </summary>
+        public void Dispose()
+        {
+            if (_disposed)
+            {
+                return;
+            }
+
+            _disposed = true;
+
+            try
+            {
+                _flushTimer.Dispose();
+
+                // Final flush
+                FlushAsync().GetAwaiter().GetResult();
+            }
+            catch (Exception ex)
+            {
+                Debug.WriteLine($"[TRACE] MetricsAggregator: Error during dispose: {ex.Message}");
+            }
+        }
+
+        #region Private Methods
+
+        /// <summary>
+        /// Timer callback for periodic flushing.
+        /// </summary>
+        private void OnFlushTimerTick(object? state)
+        {
+            if (_disposed)
+            {
+                return;
+            }
+
+            try
+            {
+                _ = FlushAsync();
+            }
+            catch (Exception ex)
+            {
+                Debug.WriteLine($"[TRACE] MetricsAggregator: Error in flush timer: {ex.Message}");
+            }
+        }
+
+        /// <summary>
+        /// Determines the telemetry event type based on the activity.
+        /// </summary>
+        private static TelemetryEventType DetermineEventType(Activity activity)
+        {
+            // Check for errors first
+            if (activity.GetTagItem("error.type") != null)
+            {
+                return TelemetryEventType.Error;
+            }
+
+            // Map based on operation name
+            var operationName = activity.OperationName ??
string.Empty; + + if (operationName.StartsWith("Connection.", StringComparison.OrdinalIgnoreCase) || + operationName.Equals("OpenConnection", StringComparison.OrdinalIgnoreCase) || + operationName.Equals("OpenAsync", StringComparison.OrdinalIgnoreCase)) + { + return TelemetryEventType.ConnectionOpen; + } + + // Default to statement execution for statement operations and others + return TelemetryEventType.StatementExecution; + } + + /// + /// Processes a connection activity and emits it immediately. + /// + private void ProcessConnectionActivity(Activity activity) + { + var sessionId = GetTagValue(activity, "session.id"); + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + SystemConfiguration = ExtractSystemConfiguration(activity), + ConnectionParameters = ExtractConnectionParameters(activity) + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Connection event emitted immediately"); + } + + /// + /// Processes a statement activity and aggregates it by statement_id. 
+ /// + private void ProcessStatementActivity(Activity activity) + { + var statementId = GetTagValue(activity, "statement.id"); + var sessionId = GetTagValue(activity, "session.id"); + + if (string.IsNullOrEmpty(statementId)) + { + // No statement ID, cannot aggregate - emit immediately + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + SqlExecutionEvent = ExtractSqlExecutionEvent(activity) + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + return; + } + + // Get or create context for this statement + var context = _statementContexts.GetOrAdd( + statementId, + _ => new StatementTelemetryContext(statementId, sessionId)); + + // Aggregate metrics + context.AddLatency((long)activity.Duration.TotalMilliseconds); + AggregateActivityTags(context, activity); + } + + /// + /// Processes an error activity and emits it immediately. + /// + private void ProcessErrorActivity(Activity activity) + { + var statementId = GetTagValue(activity, "statement.id"); + var sessionId = GetTagValue(activity, "session.id"); + var errorType = GetTagValue(activity, "error.type"); + var errorMessage = GetTagValue(activity, "error.message"); + var errorCode = GetTagValue(activity, "error.code"); + + var telemetryEvent = new TelemetryEvent + { + SessionId = sessionId, + SqlStatementId = statementId, + OperationLatencyMs = (long)activity.Duration.TotalMilliseconds, + ErrorInfo = new DriverErrorInfo + { + ErrorType = errorType, + ErrorMessage = errorMessage, + ErrorCode = errorCode + } + }; + + var frontendLog = WrapInFrontendLog(telemetryEvent); + _pendingEvents.Enqueue(frontendLog); + + Debug.WriteLine($"[TRACE] MetricsAggregator: Error event emitted immediately"); + } + + /// + /// Aggregates activity tags into the statement context. 
+ /// + private void AggregateActivityTags(StatementTelemetryContext context, Activity activity) + { + var eventType = TelemetryEventType.StatementExecution; + + foreach (var tag in activity.Tags) + { + // Filter using TelemetryTagRegistry + if (!TelemetryTagRegistry.ShouldExportToDatabricks(eventType, tag.Key)) + { + continue; + } + + switch (tag.Key) + { + case "result.format": + context.ResultFormat = tag.Value; + break; + case "result.chunk_count": + if (int.TryParse(tag.Value, out int chunkCount)) + { + context.ChunkCount = (context.ChunkCount ?? 0) + chunkCount; + } + break; + case "result.bytes_downloaded": + if (long.TryParse(tag.Value, out long bytesDownloaded)) + { + context.BytesDownloaded = (context.BytesDownloaded ?? 0) + bytesDownloaded; + } + break; + case "result.compression_enabled": + if (bool.TryParse(tag.Value, out bool compressionEnabled)) + { + context.CompressionEnabled = compressionEnabled; + } + break; + case "result.row_count": + if (long.TryParse(tag.Value, out long rowCount)) + { + context.RowCount = rowCount; + } + break; + case "poll.count": + if (int.TryParse(tag.Value, out int pollCount)) + { + context.PollCount = (context.PollCount ?? 0) + pollCount; + } + break; + case "poll.latency_ms": + if (long.TryParse(tag.Value, out long pollLatencyMs)) + { + context.PollLatencyMs = (context.PollLatencyMs ?? 0) + pollLatencyMs; + } + break; + case "execution.status": + context.ExecutionStatus = tag.Value; + break; + case "statement.type": + context.StatementType = tag.Value; + break; + } + } + } + + /// + /// Creates a TelemetryEvent from a statement context. 
+ /// + private TelemetryEvent CreateTelemetryEvent(StatementTelemetryContext context) + { + return new TelemetryEvent + { + SessionId = context.SessionId, + SqlStatementId = context.StatementId, + OperationLatencyMs = context.TotalLatencyMs, + SqlExecutionEvent = new SqlExecutionEvent + { + ResultFormat = context.ResultFormat, + ChunkCount = context.ChunkCount, + BytesDownloaded = context.BytesDownloaded, + CompressionEnabled = context.CompressionEnabled, + RowCount = context.RowCount, + PollCount = context.PollCount, + PollLatencyMs = context.PollLatencyMs, + ExecutionStatus = context.ExecutionStatus, + StatementType = context.StatementType + } + }; + } + + /// + /// Creates an error TelemetryEvent from an exception. + /// + private TelemetryEvent CreateErrorTelemetryEvent( + string? sessionId, + string statementId, + Exception exception) + { + var isTerminal = ExceptionClassifier.IsTerminalException(exception); + int? httpStatusCode = null; + +#if NET5_0_OR_GREATER + if (exception is System.Net.Http.HttpRequestException httpEx && httpEx.StatusCode.HasValue) + { + httpStatusCode = (int)httpEx.StatusCode.Value; + } +#endif + + return new TelemetryEvent + { + SessionId = sessionId, + SqlStatementId = statementId, + ErrorInfo = new DriverErrorInfo + { + ErrorType = exception.GetType().Name, + ErrorMessage = SanitizeErrorMessage(exception.Message), + IsTerminal = isTerminal, + HttpStatusCode = httpStatusCode + } + }; + } + + /// + /// Sanitizes an error message to remove potential PII. + /// + private static string SanitizeErrorMessage(string? message) + { + if (string.IsNullOrEmpty(message)) + { + return string.Empty; + } + + // Truncate long messages + const int maxLength = 200; + if (message.Length > maxLength) + { + message = message.Substring(0, maxLength) + "..."; + } + + return message; + } + + /// + /// Wraps a TelemetryEvent in a TelemetryFrontendLog. 
+ /// + internal TelemetryFrontendLog WrapInFrontendLog(TelemetryEvent telemetryEvent) + { + return new TelemetryFrontendLog + { + WorkspaceId = _workspaceId, + FrontendLogEventId = Guid.NewGuid().ToString(), + Context = new FrontendLogContext + { + ClientContext = new TelemetryClientContext + { + UserAgent = _userAgent + }, + TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() + }, + Entry = new FrontendLogEntry + { + SqlDriverLog = telemetryEvent + } + }; + } + + /// + /// Extracts system configuration from activity tags. + /// + private static DriverSystemConfiguration? ExtractSystemConfiguration(Activity activity) + { + var driverVersion = GetTagValue(activity, "driver.version"); + var osName = GetTagValue(activity, "driver.os"); + var runtime = GetTagValue(activity, "driver.runtime"); + + if (driverVersion == null && osName == null && runtime == null) + { + return null; + } + + return new DriverSystemConfiguration + { + DriverName = "Databricks ADBC Driver", + DriverVersion = driverVersion, + OsName = osName, + RuntimeName = ".NET", + RuntimeVersion = runtime + }; + } + + /// + /// Extracts connection parameters from activity tags. + /// + private static DriverConnectionParameters? ExtractConnectionParameters(Activity activity) + { + var cloudFetchEnabled = GetTagValue(activity, "feature.cloudfetch"); + var lz4Enabled = GetTagValue(activity, "feature.lz4"); + var directResultsEnabled = GetTagValue(activity, "feature.direct_results"); + + if (cloudFetchEnabled == null && lz4Enabled == null && directResultsEnabled == null) + { + return null; + } + + return new DriverConnectionParameters + { + CloudFetchEnabled = bool.TryParse(cloudFetchEnabled, out var cf) ? cf : (bool?)null, + Lz4CompressionEnabled = bool.TryParse(lz4Enabled, out var lz4) ? lz4 : (bool?)null, + DirectResultsEnabled = bool.TryParse(directResultsEnabled, out var dr) ? dr : (bool?)null + }; + } + + /// + /// Extracts SQL execution event details from activity tags. 
+        /// </summary>
+        private static SqlExecutionEvent? ExtractSqlExecutionEvent(Activity activity)
+        {
+            var resultFormat = GetTagValue(activity, "result.format");
+            var chunkCountStr = GetTagValue(activity, "result.chunk_count");
+            var bytesDownloadedStr = GetTagValue(activity, "result.bytes_downloaded");
+            var pollCountStr = GetTagValue(activity, "poll.count");
+
+            if (resultFormat == null && chunkCountStr == null && bytesDownloadedStr == null && pollCountStr == null)
+            {
+                return null;
+            }
+
+            return new SqlExecutionEvent
+            {
+                ResultFormat = resultFormat,
+                ChunkCount = int.TryParse(chunkCountStr, out var cc) ? cc : (int?)null,
+                BytesDownloaded = long.TryParse(bytesDownloadedStr, out var bd) ? bd : (long?)null,
+                PollCount = int.TryParse(pollCountStr, out var pc) ? pc : (int?)null
+            };
+        }
+
+        /// <summary>
+        /// Gets a tag value from an activity.
+        /// </summary>
+        private static string? GetTagValue(Activity activity, string tagName)
+        {
+            return activity.GetTagItem(tagName)?.ToString();
+        }
+
+        #endregion
+
+        #region Nested Classes
+
+        /// <summary>
+        /// Holds aggregated telemetry data for a statement.
+        /// </summary>
+        internal sealed class StatementTelemetryContext
+        {
+            private long _totalLatencyMs;
+
+            public string StatementId { get; }
+            public string? SessionId { get; set; }
+            public long TotalLatencyMs => Interlocked.Read(ref _totalLatencyMs);
+
+            // Aggregated metrics
+            public string? ResultFormat { get; set; }
+            public int? ChunkCount { get; set; }
+            public long? BytesDownloaded { get; set; }
+            public bool? CompressionEnabled { get; set; }
+            public long? RowCount { get; set; }
+            public int? PollCount { get; set; }
+            public long? PollLatencyMs { get; set; }
+            public string? ExecutionStatus { get; set; }
+            public string? StatementType { get; set; }
+
+            // Buffered exceptions for retryable errors
+            public ConcurrentBag<Exception> BufferedExceptions { get; } = new ConcurrentBag<Exception>();
+
+            public StatementTelemetryContext(string statementId, string? sessionId)
+            {
+                StatementId = statementId ??
throw new ArgumentNullException(nameof(statementId)); + SessionId = sessionId; + } + + public void AddLatency(long latencyMs) + { + Interlocked.Add(ref _totalLatencyMs, latencyMs); + } + } + + #endregion + } +} diff --git a/csharp/src/Telemetry/TelemetryConfiguration.cs b/csharp/src/Telemetry/TelemetryConfiguration.cs index d2b1732c..a66dbb5c 100644 --- a/csharp/src/Telemetry/TelemetryConfiguration.cs +++ b/csharp/src/Telemetry/TelemetryConfiguration.cs @@ -84,6 +84,13 @@ public sealed class TelemetryConfiguration /// public const string FeatureFlagName = "databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetryForAdbc"; + /// + /// Feature flag endpoint format (relative to host). + /// {0} = driver version without OSS suffix. + /// NOTE: Using OSS_JDBC endpoint until OSS_ADBC is configured server-side. + /// + public const string FeatureFlagEndpointFormat = "/api/2.0/connector-service/feature-flags/OSS_JDBC/{0}"; + /// /// Gets or sets whether telemetry is enabled. /// Default is true. diff --git a/csharp/test/E2E/FeatureFlagCacheE2ETest.cs b/csharp/test/E2E/FeatureFlagCacheE2ETest.cs new file mode 100644 index 00000000..1c674efe --- /dev/null +++ b/csharp/test/E2E/FeatureFlagCacheE2ETest.cs @@ -0,0 +1,237 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Apache.Arrow.Adbc.Tests; +using Xunit; +using Xunit.Abstractions; + +namespace AdbcDrivers.Databricks.Tests +{ + /// + /// End-to-end tests for the FeatureFlagCache functionality using a real Databricks instance. + /// Tests that feature flags are properly fetched and cached from the Databricks connector service. + /// + public class FeatureFlagCacheE2ETest : TestBase + { + public FeatureFlagCacheE2ETest(ITestOutputHelper? outputHelper) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) + { + // Skip the test if the DATABRICKS_TEST_CONFIG_FILE environment variable is not set + Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable)); + } + + /// + /// Tests that creating a connection successfully initializes the feature flag cache + /// and verifies that flags are actually fetched from the server. + /// + [SkippableFact] + public async Task TestFeatureFlagCacheInitialization() + { + // Arrange + var cache = FeatureFlagCache.GetInstance(); + var hostName = GetNormalizedHostName(); + Skip.If(string.IsNullOrEmpty(hostName), "Cannot determine host name from test configuration"); + + // Act - Create a connection which initializes the feature flag cache + using var connection = NewConnection(TestConfiguration); + + // Assert - The connection should be created successfully + Assert.NotNull(connection); + + // Verify the feature flag context exists for this host + Assert.True(cache.TryGetContext(hostName!, out var context), "Feature flag context should exist after connection creation"); + Assert.NotNull(context); + + // Verify that some flags were fetched from the server + // The server should return at least some feature flags + var flags = context.GetAllFlags(); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Fetched {flags.Count} feature flags from server"); + foreach (var flag in flags) + { + OutputHelper?.WriteLine($" - {flag.Key}: {flag.Value}"); + } + + 
// Log the TTL from the server + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] TTL from server: {context.Ttl.TotalSeconds} seconds"); + + // Note: We don't assert flags.Count > 0 because the server may return empty flags + // in some environments, but we verify the infrastructure works + + // Execute a simple query to verify the connection works + using var statement = connection.CreateStatement(); + statement.SqlQuery = "SELECT 1 as test_value"; + var result = await statement.ExecuteQueryAsync(); + + Assert.NotNull(result.Stream); + var batch = await result.Stream.ReadNextRecordBatchAsync(); + Assert.NotNull(batch); + Assert.Equal(1, batch.Length); + + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Connection with feature flag cache initialized successfully"); + } + + /// + /// Tests that multiple connections to the same host share the same feature flag context. + /// This verifies the per-host caching behavior. + /// + [SkippableFact] + public async Task TestFeatureFlagCacheSharedAcrossConnections() + { + // Arrange - Get the singleton cache instance + var cache = FeatureFlagCache.GetInstance(); + + // Act - Create two connections to the same host + using var connection1 = NewConnection(TestConfiguration); + using var connection2 = NewConnection(TestConfiguration); + + // Assert - Both connections should work properly + Assert.NotNull(connection1); + Assert.NotNull(connection2); + + // Verify both connections can execute queries + using var statement1 = connection1.CreateStatement(); + statement1.SqlQuery = "SELECT 1 as conn1_test"; + var result1 = await statement1.ExecuteQueryAsync(); + Assert.NotNull(result1.Stream); + + using var statement2 = connection2.CreateStatement(); + statement2.SqlQuery = "SELECT 2 as conn2_test"; + var result2 = await statement2.ExecuteQueryAsync(); + Assert.NotNull(result2.Stream); + + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Multiple connections sharing feature flag cache work correctly"); + } + + /// + /// Tests 
that the feature flag cache persists after connections close (TTL-based expiration). + /// With the IMemoryCache implementation, contexts stay in cache until TTL expires, + /// not when connections close. + /// + [SkippableFact] + public async Task TestFeatureFlagCachePersistsAfterConnectionClose() + { + // Arrange + var cache = FeatureFlagCache.GetInstance(); + var hostName = GetNormalizedHostName(); + Skip.If(string.IsNullOrEmpty(hostName), "Cannot determine host name from test configuration"); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Initial cache count: {cache.CachedHostCount}"); + + // Act - Create and close a connection + using (var connection = NewConnection(TestConfiguration)) + { + // Connection is active, cache should have a context for this host + Assert.NotNull(connection); + + // Verify context exists during connection + Assert.True(cache.TryGetContext(hostName!, out var context), "Context should exist while connection is active"); + Assert.NotNull(context); + + // Verify flags were fetched + var flags = context.GetAllFlags(); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Flags fetched: {flags.Count}"); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] TTL: {context.Ttl.TotalSeconds} seconds"); + + // Execute a query to ensure the connection is fully initialized + using var statement = connection.CreateStatement(); + statement.SqlQuery = "SELECT 1"; + var result = await statement.ExecuteQueryAsync(); + Assert.NotNull(result.Stream); + } + // Connection is disposed here + + // With TTL-based caching, the context should still exist in the cache + // (it only gets removed when TTL expires or cache is explicitly cleared) + Assert.True(cache.TryGetContext(hostName!, out var contextAfterDispose), + "Context should still exist in cache after connection close (TTL-based expiration)"); + Assert.NotNull(contextAfterDispose); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Context still exists after connection close with TTL: 
{contextAfterDispose.Ttl.TotalSeconds}s"); + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Final cache count: {cache.CachedHostCount}"); + OutputHelper?.WriteLine("[FeatureFlagCacheE2ETest] Feature flag cache TTL-based persistence test completed"); + } + + /// + /// Tests that connections work correctly with feature flags enabled. + /// This is a basic sanity check that the feature flag infrastructure doesn't + /// interfere with normal connection operations. + /// + [SkippableFact] + public async Task TestConnectionWithFeatureFlagsExecutesQueries() + { + // Arrange + using var connection = NewConnection(TestConfiguration); + + // Act - Execute multiple queries to ensure feature flags don't interfere + var queries = new[] + { + "SELECT 1 as value", + "SELECT 'hello' as greeting", + "SELECT CURRENT_DATE() as today" + }; + + foreach (var query in queries) + { + using var statement = connection.CreateStatement(); + statement.SqlQuery = query; + var result = await statement.ExecuteQueryAsync(); + + // Assert + Assert.NotNull(result.Stream); + var batch = await result.Stream.ReadNextRecordBatchAsync(); + Assert.NotNull(batch); + Assert.True(batch.Length > 0, $"Query '{query}' should return at least one row"); + + OutputHelper?.WriteLine($"[FeatureFlagCacheE2ETest] Query executed successfully: {query}"); + } + } + + /// + /// Gets the normalized host name from test configuration. + /// Strips protocol prefix if present (e.g., "https://host" -> "host"). + /// + private string? GetNormalizedHostName() + { + var hostName = TestConfiguration.HostName ?? TestConfiguration.Uri; + if (string.IsNullOrEmpty(hostName)) + { + return null; + } + + // Try to parse as URI first + if (Uri.TryCreate(hostName, UriKind.Absolute, out Uri? 
parsedUri) && + (parsedUri.Scheme == Uri.UriSchemeHttp || parsedUri.Scheme == Uri.UriSchemeHttps)) + { + return parsedUri.Host; + } + + // Fallback: strip common protocol prefixes manually + if (hostName.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + return hostName.Substring(8); + } + if (hostName.StartsWith("http://", StringComparison.OrdinalIgnoreCase)) + { + return hostName.Substring(7); + } + + return hostName; + } + } +} diff --git a/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs b/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs new file mode 100644 index 00000000..e4bca1e1 --- /dev/null +++ b/csharp/test/E2E/Telemetry/ClientTelemetryE2ETests.cs @@ -0,0 +1,491 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Apache.Arrow.Adbc.Tests; +using Xunit; +using Xunit.Abstractions; + +namespace AdbcDrivers.Databricks.Tests.E2E.Telemetry +{ + /// + /// End-to-end tests for client telemetry that sends telemetry to Databricks endpoint. + /// These tests verify that the DatabricksTelemetryExporter can successfully send + /// telemetry events to real Databricks telemetry endpoints. 
+ /// + public class ClientTelemetryE2ETests : TestBase + { + public ClientTelemetryE2ETests(ITestOutputHelper? outputHelper) + : base(outputHelper, new DatabricksTestEnvironment.Factory()) + { + } + + /// + /// Tests that telemetry can be sent to the authenticated endpoint (/telemetry-ext). + /// This endpoint requires a valid authentication token. + /// + [SkippableFact] + public async Task CanSendTelemetryToAuthenticatedEndpoint() + { + // Skip if no token is available + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing authenticated telemetry endpoint at {host}/telemetry-ext"); + + // Create HttpClient with authentication + using var httpClient = CreateAuthenticatedHttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Verify endpoint URL + var endpointUrl = exporter.GetEndpointUrl(); + Assert.Equal($"{host}/telemetry-ext", endpointUrl); + OutputHelper?.WriteLine($"Endpoint URL: {endpointUrl}"); + + // Create a test telemetry log + var logs = CreateTestTelemetryLogs(1); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + // ExportAsync should return true indicating HTTP 200 response + Assert.True(success, "ExportAsync should return true indicating successful HTTP 200 response"); + OutputHelper?.WriteLine("Successfully sent telemetry to authenticated endpoint"); + } + + /// + /// Tests that telemetry can be sent to the unauthenticated endpoint (/telemetry-unauth). + /// This endpoint does not require authentication. 
+ /// + [SkippableFact] + public async Task CanSendTelemetryToUnauthenticatedEndpoint() + { + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing unauthenticated telemetry endpoint at {host}/telemetry-unauth"); + + // Create HttpClient without authentication + using var httpClient = new HttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: false, config); + + // Verify endpoint URL + var endpointUrl = exporter.GetEndpointUrl(); + Assert.Equal($"{host}/telemetry-unauth", endpointUrl); + OutputHelper?.WriteLine($"Endpoint URL: {endpointUrl}"); + + // Create a test telemetry log + var logs = CreateTestTelemetryLogs(1); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + // ExportAsync should return true indicating HTTP 200 response + Assert.True(success, "ExportAsync should return true indicating successful HTTP 200 response"); + OutputHelper?.WriteLine("Successfully sent telemetry to unauthenticated endpoint"); + } + + /// + /// Tests that multiple telemetry logs can be batched and sent together. 
+ /// + [SkippableFact] + public async Task CanSendBatchedTelemetryLogs() + { + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing batched telemetry to {host}"); + + using var httpClient = CreateAuthenticatedHttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Create multiple telemetry logs + var logs = CreateTestTelemetryLogs(5); + OutputHelper?.WriteLine($"Created {logs.Count} telemetry logs for batch send"); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + Assert.True(success, "ExportAsync should return true for batched telemetry"); + OutputHelper?.WriteLine("Successfully sent batched telemetry logs"); + } + + /// + /// Tests that the telemetry request is properly formatted. 
+ /// + [SkippableFact] + public void TelemetryRequestIsProperlyFormatted() + { + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Create test logs + var logs = CreateTestTelemetryLogs(2); + + // Create the request + var request = exporter.CreateTelemetryRequest(logs); + + // Verify request structure + Assert.True(request.UploadTime > 0, "UploadTime should be a positive timestamp"); + Assert.Equal(2, request.ProtoLogs.Count); + + // Verify each log is serialized as JSON + foreach (var protoLog in request.ProtoLogs) + { + Assert.NotEmpty(protoLog); + Assert.Contains("workspace_id", protoLog); + Assert.Contains("frontend_log_event_id", protoLog); + OutputHelper?.WriteLine($"Serialized log: {protoLog}"); + } + + // Verify the full request serialization + var json = exporter.SerializeRequest(request); + Assert.Contains("uploadTime", json); + Assert.Contains("protoLogs", json); + OutputHelper?.WriteLine($"Full request JSON: {json}"); + } + + /// + /// Tests telemetry with a complete TelemetryFrontendLog including all nested objects. 
+ /// + [SkippableFact] + public async Task CanSendCompleteTelemetryEvent() + { + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine($"Testing complete telemetry event to {host}"); + + using var httpClient = CreateAuthenticatedHttpClient(); + + var config = new TelemetryConfiguration + { + MaxRetries = 2, + RetryDelayMs = 100 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Create a complete telemetry log with all fields populated + var log = new TelemetryFrontendLog + { + WorkspaceId = 12345678901234, + FrontendLogEventId = Guid.NewGuid().ToString(), + Context = new FrontendLogContext + { + TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), + ClientContext = new TelemetryClientContext + { + UserAgent = "AdbcDatabricksDriver/1.0.0-test (.NET; E2E Test)" + } + }, + Entry = new FrontendLogEntry + { + SqlDriverLog = new TelemetryEvent + { + SessionId = Guid.NewGuid().ToString(), + SqlStatementId = Guid.NewGuid().ToString(), + OperationLatencyMs = 150, + SystemConfiguration = new DriverSystemConfiguration + { + DriverName = "Databricks ADBC Driver", + DriverVersion = "1.0.0-test", + OsName = Environment.OSVersion.Platform.ToString(), + OsVersion = Environment.OSVersion.Version.ToString(), + RuntimeName = ".NET", + RuntimeVersion = Environment.Version.ToString(), + Locale = System.Globalization.CultureInfo.CurrentCulture.Name, + Timezone = TimeZoneInfo.Local.Id + } + } + } + }; + + var logs = new List { log }; + + OutputHelper?.WriteLine($"Sending complete telemetry event:"); + OutputHelper?.WriteLine($" WorkspaceId: {log.WorkspaceId}"); + OutputHelper?.WriteLine($" EventId: {log.FrontendLogEventId}"); + OutputHelper?.WriteLine($" 
SessionId: {log.Entry?.SqlDriverLog?.SessionId}"); + + // Send telemetry - should succeed and return true + var success = await exporter.ExportAsync(logs); + + Assert.True(success, "ExportAsync should return true for complete telemetry event"); + OutputHelper?.WriteLine("Successfully sent complete telemetry event"); + } + + /// + /// Tests that empty log lists are handled gracefully without making HTTP requests. + /// + [SkippableFact] + public async Task EmptyLogListDoesNotSendRequest() + { + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + // Send empty list - should return immediately with true (nothing to export) + var success = await exporter.ExportAsync(new List()); + + Assert.True(success, "ExportAsync should return true for empty list (nothing to export)"); + OutputHelper?.WriteLine("Empty log list handled gracefully"); + + // Send null list - should also return immediately with true (nothing to export) + success = await exporter.ExportAsync(null!); + + Assert.True(success, "ExportAsync should return true for null list (nothing to export)"); + OutputHelper?.WriteLine("Null log list handled gracefully"); + } + + /// + /// Tests that the exporter properly retries on transient failures. + /// This test verifies the retry configuration is respected. 
+ /// + [SkippableFact] + public async Task TelemetryExporterRespectsRetryConfiguration() + { + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + OutputHelper?.WriteLine("Testing retry configuration"); + + using var httpClient = CreateAuthenticatedHttpClient(); + + // Configure with specific retry settings + var config = new TelemetryConfiguration + { + MaxRetries = 3, + RetryDelayMs = 50 + }; + + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + + var logs = CreateTestTelemetryLogs(1); + + // This should succeed and return true + var success = await exporter.ExportAsync(logs); + + // Exporter should return true on success + Assert.True(success, "ExportAsync should return true with retry configuration"); + OutputHelper?.WriteLine("Retry configuration test completed"); + } + + /// + /// Tests that the authenticated telemetry endpoint returns HTTP 200. + /// This test directly calls the endpoint to verify the response status code. 
+ /// + [SkippableFact] + public async Task AuthenticatedEndpointReturnsHttp200() + { + Skip.If(string.IsNullOrEmpty(TestConfiguration.Token) && string.IsNullOrEmpty(TestConfiguration.AccessToken), + "Token is required for authenticated telemetry endpoint test"); + + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + var endpointUrl = $"{host}/telemetry-ext"; + OutputHelper?.WriteLine($"Testing HTTP response from {endpointUrl}"); + + using var httpClient = CreateAuthenticatedHttpClient(); + + // Create a minimal telemetry request + var logs = CreateTestTelemetryLogs(1); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: true, config); + var request = exporter.CreateTelemetryRequest(logs); + var json = exporter.SerializeRequest(request); + + // Send the request directly to verify HTTP status + using var content = new System.Net.Http.StringContent(json, System.Text.Encoding.UTF8, "application/json"); + using var response = await httpClient.PostAsync(endpointUrl, content); + + OutputHelper?.WriteLine($"HTTP Status Code: {(int)response.StatusCode} ({response.StatusCode})"); + OutputHelper?.WriteLine($"Response Headers: {response.Headers}"); + + var responseBody = await response.Content.ReadAsStringAsync(); + OutputHelper?.WriteLine($"Response Body: {responseBody}"); + + // Assert we get HTTP 200 + Assert.Equal(System.Net.HttpStatusCode.OK, response.StatusCode); + OutputHelper?.WriteLine("Verified: Authenticated endpoint returns HTTP 200"); + } + + /// + /// Tests that the unauthenticated telemetry endpoint returns HTTP 200. + /// This test directly calls the endpoint to verify the response status code. 
+ /// + [SkippableFact] + public async Task UnauthenticatedEndpointReturnsHttp200() + { + var host = GetDatabricksHost(); + Skip.If(string.IsNullOrEmpty(host), "Databricks host is required"); + + var endpointUrl = $"{host}/telemetry-unauth"; + OutputHelper?.WriteLine($"Testing HTTP response from {endpointUrl}"); + + using var httpClient = new HttpClient(); + + // Create a minimal telemetry request + var logs = CreateTestTelemetryLogs(1); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, host, isAuthenticated: false, config); + var request = exporter.CreateTelemetryRequest(logs); + var json = exporter.SerializeRequest(request); + + // Send the request directly to verify HTTP status + using var content = new System.Net.Http.StringContent(json, System.Text.Encoding.UTF8, "application/json"); + using var response = await httpClient.PostAsync(endpointUrl, content); + + OutputHelper?.WriteLine($"HTTP Status Code: {(int)response.StatusCode} ({response.StatusCode})"); + OutputHelper?.WriteLine($"Response Headers: {response.Headers}"); + + var responseBody = await response.Content.ReadAsStringAsync(); + OutputHelper?.WriteLine($"Response Body: {responseBody}"); + + // Assert we get HTTP 200 + Assert.Equal(System.Net.HttpStatusCode.OK, response.StatusCode); + OutputHelper?.WriteLine("Verified: Unauthenticated endpoint returns HTTP 200"); + } + + #region Helper Methods + + /// + /// Gets the Databricks host URL from the test configuration. + /// + private string GetDatabricksHost() + { + // Try Uri first, then fall back to HostName + if (!string.IsNullOrEmpty(TestConfiguration.Uri)) + { + var uri = new Uri(TestConfiguration.Uri); + return $"{uri.Scheme}://{uri.Host}"; + } + + if (!string.IsNullOrEmpty(TestConfiguration.HostName)) + { + return $"https://{TestConfiguration.HostName}"; + } + + return string.Empty; + } + + /// + /// Creates an HttpClient with authentication headers. 
+ /// + private HttpClient CreateAuthenticatedHttpClient() + { + var httpClient = new HttpClient(); + + // Use AccessToken if available, otherwise fall back to Token + var token = !string.IsNullOrEmpty(TestConfiguration.AccessToken) + ? TestConfiguration.AccessToken + : TestConfiguration.Token; + + if (!string.IsNullOrEmpty(token)) + { + httpClient.DefaultRequestHeaders.Authorization = + new AuthenticationHeaderValue("Bearer", token); + } + + return httpClient; + } + + /// + /// Creates test telemetry logs for E2E testing. + /// + private IReadOnlyList CreateTestTelemetryLogs(int count) + { + var logs = new List(count); + + for (int i = 0; i < count; i++) + { + logs.Add(new TelemetryFrontendLog + { + WorkspaceId = 5870029948831567, + FrontendLogEventId = Guid.NewGuid().ToString(), + Context = new FrontendLogContext + { + TimestampMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(), + ClientContext = new TelemetryClientContext + { + UserAgent = $"AdbcDatabricksDriver/1.0.0-test (.NET; E2E Test {i})" + } + }, + Entry = new FrontendLogEntry + { + SqlDriverLog = new TelemetryEvent + { + SessionId = Guid.NewGuid().ToString(), + SqlStatementId = Guid.NewGuid().ToString(), + OperationLatencyMs = 100 + (i * 10) + } + } + }); + } + + return logs; + } + + #endregion + } +} diff --git a/csharp/test/Unit/FeatureFlagCacheTests.cs b/csharp/test/Unit/FeatureFlagCacheTests.cs new file mode 100644 index 00000000..e61e7d91 --- /dev/null +++ b/csharp/test/Unit/FeatureFlagCacheTests.cs @@ -0,0 +1,773 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks; +using Moq; +using Moq.Protected; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit +{ + /// + /// Tests for FeatureFlagCache and FeatureFlagContext classes. + /// + public class FeatureFlagCacheTests + { + private const string TestHost = "test-host.databricks.com"; + private const string DriverVersion = "1.0.0"; + + #region FeatureFlagContext Tests - Basic Functionality + + [Fact] + public void FeatureFlagContext_GetFlagValue_ReturnsValue() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + + // Act & Assert + Assert.Equal("value1", context.GetFlagValue("flag1")); + Assert.Equal("value2", context.GetFlagValue("flag2")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_NotFound_ReturnsNull() + { + // Arrange + var context = CreateTestContext(); + + // Act & Assert + Assert.Null(context.GetFlagValue("nonexistent")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_NullOrEmpty_ReturnsNull() + { + // Arrange + var context = CreateTestContext(); + + // Act & Assert + Assert.Null(context.GetFlagValue(null!)); + Assert.Null(context.GetFlagValue("")); + Assert.Null(context.GetFlagValue(" ")); + } + + [Fact] + public void FeatureFlagContext_GetFlagValue_CaseInsensitive() + { + // Arrange + var flags = new Dictionary + { + 
["MyFlag"] = "value" + }; + var context = CreateTestContext(flags); + + // Act & Assert + Assert.Equal("value", context.GetFlagValue("myflag")); + Assert.Equal("value", context.GetFlagValue("MYFLAG")); + Assert.Equal("value", context.GetFlagValue("MyFlag")); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_ReturnsAllFlags() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2", + ["flag3"] = "value3" + }; + var context = CreateTestContext(flags); + + // Act + var allFlags = context.GetAllFlags(); + + // Assert + Assert.Equal(3, allFlags.Count); + Assert.Equal("value1", allFlags["flag1"]); + Assert.Equal("value2", allFlags["flag2"]); + Assert.Equal("value3", allFlags["flag3"]); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_ReturnsSnapshot() + { + // Arrange + var context = CreateTestContext(); + context.SetFlag("flag1", "value1"); + + // Act + var snapshot = context.GetAllFlags(); + context.SetFlag("flag2", "value2"); + + // Assert - snapshot should not include new flag + Assert.Single(snapshot); + Assert.Equal("value1", snapshot["flag1"]); + } + + [Fact] + public void FeatureFlagContext_GetAllFlags_Empty_ReturnsEmptyDictionary() + { + // Arrange + var context = CreateTestContext(); + + // Act + var allFlags = context.GetAllFlags(); + + // Assert + Assert.Empty(allFlags); + } + + #endregion + + #region FeatureFlagContext Tests - TTL + + [Fact] + public void FeatureFlagContext_DefaultTtl_Is15Minutes() + { + // Arrange + var context = CreateTestContext(); + + // Assert + Assert.Equal(TimeSpan.FromMinutes(15), context.Ttl); + Assert.Equal(TimeSpan.FromMinutes(15), context.RefreshInterval); // Alias + } + + [Fact] + public void FeatureFlagContext_CustomTtl() + { + // Arrange + var customTtl = TimeSpan.FromMinutes(5); + var context = CreateTestContext(null, customTtl); + + // Assert + Assert.Equal(customTtl, context.Ttl); + Assert.Equal(customTtl, context.RefreshInterval); + } + + #endregion + + #region 
FeatureFlagContext Tests - Dispose + + [Fact] + public void FeatureFlagContext_Dispose_CanBeCalledMultipleTimes() + { + // Arrange + var context = CreateTestContext(); + + // Act - should not throw + context.Dispose(); + context.Dispose(); + context.Dispose(); + } + + #endregion + + #region FeatureFlagContext Tests - Internal Methods + + [Fact] + public void FeatureFlagContext_SetFlag_AddsOrUpdatesFlag() + { + // Arrange + var context = CreateTestContext(); + + // Act + context.SetFlag("flag1", "value1"); + context.SetFlag("flag2", "value2"); + context.SetFlag("flag1", "updated"); + + // Assert + Assert.Equal("updated", context.GetFlagValue("flag1")); + Assert.Equal("value2", context.GetFlagValue("flag2")); + } + + [Fact] + public void FeatureFlagContext_ClearFlags_RemovesAllFlags() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + + // Act + context.ClearFlags(); + + // Assert + Assert.Empty(context.GetAllFlags()); + } + + #endregion + + #region FeatureFlagCache Singleton Tests + + [Fact] + public void FeatureFlagCache_GetInstance_ReturnsSingleton() + { + // Act + var instance1 = FeatureFlagCache.GetInstance(); + var instance2 = FeatureFlagCache.GetInstance(); + + // Assert + Assert.Same(instance1, instance2); + } + + #endregion + + #region FeatureFlagCache_GetOrCreateContext Tests + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NewHost_CreatesContext() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context = cache.GetOrCreateContext("test-host-1.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.NotNull(context); + Assert.True(cache.HasContext("test-host-1.databricks.com")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ExistingHost_ReturnsSameContext() + { + // Arrange + var cache = new 
FeatureFlagCache(); + var host = "test-host-2.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext(host, httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Assert + Assert.Same(context1, context2); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_MultipleHosts_CreatesMultipleContexts() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext("host1.databricks.com", httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext("host2.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.NotSame(context1, context2); + Assert.Equal(2, cache.CachedHostCount); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NullHost_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext(null!, httpClient, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_EmptyHost_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext("", httpClient, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_NullHttpClient_ThrowsException() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act & Assert + Assert.Throws(() => cache.GetOrCreateContext(TestHost, null!, DriverVersion)); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_CaseInsensitive() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "Test-Host.Databricks.com"; 
+ var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + + // Act + var context1 = cache.GetOrCreateContext(host.ToLower(), httpClient, DriverVersion); + var context2 = cache.GetOrCreateContext(host.ToUpper(), httpClient, DriverVersion); + + // Assert + Assert.Same(context1, context2); + Assert.Equal(1, cache.CachedHostCount); + + // Cleanup + cache.Clear(); + } + + #endregion + + #region FeatureFlagCache_RemoveContext Tests + + [Fact] + public void FeatureFlagCache_RemoveContext_RemovesContext() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "test-host-3.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Act + cache.RemoveContext(host); + + // Assert + Assert.False(cache.HasContext(host)); + Assert.Equal(0, cache.CachedHostCount); + } + + [Fact] + public void FeatureFlagCache_RemoveContext_UnknownHost_DoesNothing() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act - should not throw + cache.RemoveContext("unknown-host.databricks.com"); + + // Assert + Assert.Equal(0, cache.CachedHostCount); + } + + [Fact] + public void FeatureFlagCache_RemoveContext_NullHost_DoesNothing() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act - should not throw + cache.RemoveContext(null!); + } + + #endregion + + #region FeatureFlagCache with API Response Tests + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ParsesFlags() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "flag1", Value = "value1" }, + new FeatureFlagEntry { Name = "flag2", Value = "true" } + }, + TtlSeconds = 300 + }; + var httpClient = CreateMockHttpClient(response); + + // Act + var context = cache.GetOrCreateContext("test-api.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.Equal("value1", 
context.GetFlagValue("flag1")); + Assert.Equal("true", context.GetFlagValue("flag2")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_UpdatesTtl() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List(), + TtlSeconds = 300 // 5 minutes + }; + var httpClient = CreateMockHttpClient(response); + + // Act + var context = cache.GetOrCreateContext("test-ttl.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.Equal(TimeSpan.FromSeconds(300), context.Ttl); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_GetOrCreateContext_ApiError_DoesNotThrow() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(HttpStatusCode.InternalServerError); + + // Act - should not throw + var context = cache.GetOrCreateContext("test-error.databricks.com", httpClient, DriverVersion); + + // Assert + Assert.NotNull(context); + Assert.Empty(context.GetAllFlags()); + + // Cleanup + cache.Clear(); + } + + #endregion + + #region FeatureFlagCache Helper Method Tests + + [Fact] + public void FeatureFlagCache_TryGetContext_ExistingContext_ReturnsTrue() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "try-get-host.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + var expectedContext = cache.GetOrCreateContext(host, httpClient, DriverVersion); + + // Act + var result = cache.TryGetContext(host, out var context); + + // Assert + Assert.True(result); + Assert.Same(expectedContext, context); + + // Cleanup + cache.Clear(); + } + + [Fact] + public void FeatureFlagCache_TryGetContext_UnknownHost_ReturnsFalse() + { + // Arrange + var cache = new FeatureFlagCache(); + + // Act + var result = cache.TryGetContext("unknown.databricks.com", out var context); + + // Assert + Assert.False(result); + Assert.Null(context); + } + + [Fact] + public void 
FeatureFlagCache_Clear_RemovesAllContexts() + { + // Arrange + var cache = new FeatureFlagCache(); + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + cache.GetOrCreateContext("host1.databricks.com", httpClient, DriverVersion); + cache.GetOrCreateContext("host2.databricks.com", httpClient, DriverVersion); + cache.GetOrCreateContext("host3.databricks.com", httpClient, DriverVersion); + Assert.Equal(3, cache.CachedHostCount); + + // Act + cache.Clear(); + + // Assert + Assert.Equal(0, cache.CachedHostCount); + } + + #endregion + + #region Async Initial Fetch Tests + + [Fact] + public async Task FeatureFlagCache_GetOrCreateContextAsync_AwaitsInitialFetch_FlagsAvailableImmediately() + { + // Arrange + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "async_flag1", Value = "async_value1" }, + new FeatureFlagEntry { Name = "async_flag2", Value = "async_value2" } + }, + TtlSeconds = 300 + }; + var httpClient = CreateMockHttpClient(response); + + // Act - Use async method explicitly + var context = await cache.GetOrCreateContextAsync("test-async.databricks.com", httpClient, DriverVersion); + + // Assert - Flags should be immediately available after await completes + // This verifies that GetOrCreateContextAsync waits for the initial fetch + Assert.Equal("async_value1", context.GetFlagValue("async_flag1")); + Assert.Equal("async_value2", context.GetFlagValue("async_flag2")); + Assert.Equal(2, context.GetAllFlags().Count); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagCache_GetOrCreateContextAsync_WithDelayedResponse_StillAwaitsInitialFetch() + { + // Arrange - Create a mock that simulates network delay + var cache = new FeatureFlagCache(); + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "delayed_flag", Value = "delayed_value" } + }, + TtlSeconds = 300 + }; + var httpClient = 
CreateDelayedMockHttpClient(response, delayMs: 100); + + // Act - Measure time to verify we actually waited + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var context = await cache.GetOrCreateContextAsync("test-delayed.databricks.com", httpClient, DriverVersion); + stopwatch.Stop(); + + // Assert - Should have waited for the delayed response + Assert.True(stopwatch.ElapsedMilliseconds >= 50, "Should have waited for the delayed fetch"); + + // Flags should be available immediately after await + Assert.Equal("delayed_value", context.GetFlagValue("delayed_flag")); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagContext_CreateAsync_AwaitsInitialFetch_FlagsPopulated() + { + // Arrange + var response = new FeatureFlagsResponse + { + Flags = new List + { + new FeatureFlagEntry { Name = "create_async_flag", Value = "create_async_value" } + }, + TtlSeconds = 600 + }; + var httpClient = CreateMockHttpClient(response); + + // Act - Call CreateAsync directly + var context = await FeatureFlagContext.CreateAsync( + "test-create-async.databricks.com", + httpClient, + DriverVersion); + + // Assert - Flags should be populated after CreateAsync completes + Assert.Equal("create_async_value", context.GetFlagValue("create_async_flag")); + Assert.Equal(TimeSpan.FromSeconds(600), context.Ttl); + + // Cleanup + context.Dispose(); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async Task FeatureFlagCache_ConcurrentGetOrCreateContext_ThreadSafe() + { + // Arrange + var cache = new FeatureFlagCache(); + var host = "concurrent-host.databricks.com"; + var httpClient = CreateMockHttpClient(new FeatureFlagsResponse()); + var tasks = new Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + tasks[i] = Task.Run(() => cache.GetOrCreateContext(host, httpClient, DriverVersion)); + } + + var contexts = await Task.WhenAll(tasks); + + // Assert - All should be the same context + var firstContext = contexts[0]; + 
Assert.All(contexts, ctx => Assert.Same(firstContext, ctx)); + + // Cleanup + cache.Clear(); + } + + [Fact] + public async Task FeatureFlagContext_ConcurrentFlagAccess_ThreadSafe() + { + // Arrange + var flags = new Dictionary + { + ["flag1"] = "value1", + ["flag2"] = "value2" + }; + var context = CreateTestContext(flags); + var tasks = new Task[100]; + + // Act - Concurrent reads and writes + for (int i = 0; i < 100; i++) + { + var index = i; + tasks[i] = Task.Run(() => + { + // Read + var value = context.GetFlagValue("flag1"); + var all = context.GetAllFlags(); + var flag2Value = context.GetFlagValue("flag2"); + + // Write + context.SetFlag($"new_flag_{index}", $"value_{index}"); + }); + } + + await Task.WhenAll(tasks); + + // Assert - No exceptions thrown, all flags accessible + Assert.Equal("value1", context.GetFlagValue("flag1")); + var allFlags = context.GetAllFlags(); + Assert.True(allFlags.Count >= 2); // At least original flags + } + + #endregion + + #region Helper Methods + + /// + /// Creates a FeatureFlagContext for unit testing with pre-populated flags. + /// Does not make API calls or start background refresh. + /// + private static FeatureFlagContext CreateTestContext( + IReadOnlyDictionary? initialFlags = null, + TimeSpan? 
ttl = null) + { + var context = new FeatureFlagContext( + host: "test-host", + httpClient: null, + driverVersion: DriverVersion, + endpointFormat: null); + + if (ttl.HasValue) + { + context.Ttl = ttl.Value; + } + + if (initialFlags != null) + { + foreach (var kvp in initialFlags) + { + context.SetFlag(kvp.Key, kvp.Value); + } + } + + return context; + } + + private static HttpClient CreateMockHttpClient(FeatureFlagsResponse response) + { + var json = JsonSerializer.Serialize(response); + return CreateMockHttpClient(HttpStatusCode.OK, json); + } + + private static HttpClient CreateMockHttpClient(HttpStatusCode statusCode, string content = "") + { + var mockHandler = new Mock(); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .ReturnsAsync(new HttpResponseMessage + { + StatusCode = statusCode, + Content = new StringContent(content) + }); + + return new HttpClient(mockHandler.Object) + { + BaseAddress = new Uri("https://test.databricks.com") + }; + } + + private static HttpClient CreateMockHttpClient(HttpStatusCode statusCode) + { + return CreateMockHttpClient(statusCode, ""); + } + + private static HttpClient CreateDelayedMockHttpClient(FeatureFlagsResponse response, int delayMs) + { + var json = JsonSerializer.Serialize(response); + var mockHandler = new Mock(); + mockHandler.Protected() + .Setup>( + "SendAsync", + ItExpr.IsAny(), + ItExpr.IsAny()) + .Returns(async (HttpRequestMessage request, CancellationToken token) => + { + await Task.Delay(delayMs, token); + return new HttpResponseMessage + { + StatusCode = HttpStatusCode.OK, + Content = new StringContent(json) + }; + }); + + return new HttpClient(mockHandler.Object) + { + BaseAddress = new Uri("https://test.databricks.com") + }; + } + + #endregion + } +} diff --git a/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs b/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs new file mode 100644 index 00000000..fdb657db --- /dev/null +++ 
b/csharp/test/Unit/Telemetry/DatabricksTelemetryExporterTests.cs @@ -0,0 +1,603 @@ +/* +* Copyright (c) 2025 ADBC Drivers Contributors +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Net.Http; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using AdbcDrivers.Databricks.Telemetry; +using AdbcDrivers.Databricks.Telemetry.Models; +using Xunit; + +namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry +{ + /// + /// Tests for DatabricksTelemetryExporter class. 
+ /// + public class DatabricksTelemetryExporterTests + { + private const string TestHost = "https://test-workspace.databricks.com"; + + #region Constructor Tests + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullHttpClient_ThrowsException() + { + // Arrange + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(null!, TestHost, true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, null!, true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_EmptyHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, "", true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_WhitespaceHost_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, " ", true, config)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_NullConfig_ThrowsException() + { + // Arrange + using var httpClient = new HttpClient(); + + // Act & Assert + Assert.Throws(() => + new DatabricksTelemetryExporter(httpClient, TestHost, true, null!)); + } + + [Fact] + public void DatabricksTelemetryExporter_Constructor_ValidParameters_SetsProperties() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + + // Act + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Assert + Assert.Equal(TestHost, exporter.Host); 
+ Assert.True(exporter.IsAuthenticated); + } + + #endregion + + #region Endpoint Tests + + [Fact] + public void DatabricksTelemetryExporter_ExportAsync_Authenticated_UsesCorrectEndpoint() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-ext", endpointUrl); + } + + [Fact] + public void DatabricksTelemetryExporter_ExportAsync_Unauthenticated_UsesCorrectEndpoint() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, false, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-unauth", endpointUrl); + } + + [Fact] + public void DatabricksTelemetryExporter_GetEndpointUrl_HostWithTrailingSlash_HandlesCorrectly() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, $"{TestHost}/", true, config); + + // Act + var endpointUrl = exporter.GetEndpointUrl(); + + // Assert + Assert.Equal($"{TestHost}/telemetry-ext", endpointUrl); + } + + #endregion + + #region TelemetryRequest Creation Tests + + [Fact] + public void DatabricksTelemetryExporter_CreateTelemetryRequest_SingleLog_CreatesCorrectFormat() + { + // Arrange + using var httpClient = new HttpClient(); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog + { + WorkspaceId = 12345, + FrontendLogEventId = "test-event-id" + } + }; + + // Act + var request = exporter.CreateTelemetryRequest(logs); + + // Assert + Assert.True(request.UploadTime > 0); + 
Assert.Single(request.ProtoLogs);
+ Assert.Contains("12345", request.ProtoLogs[0]);
+ Assert.Contains("test-event-id", request.ProtoLogs[0]);
+ }
+
+ [Fact]
+ public void DatabricksTelemetryExporter_CreateTelemetryRequest_MultipleLogs_CreatesCorrectFormat()
+ {
+ // Arrange
+ using var httpClient = new HttpClient();
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ var logs = new List<TelemetryFrontendLog>
+ {
+ new TelemetryFrontendLog { WorkspaceId = 1, FrontendLogEventId = "event-1" },
+ new TelemetryFrontendLog { WorkspaceId = 2, FrontendLogEventId = "event-2" },
+ new TelemetryFrontendLog { WorkspaceId = 3, FrontendLogEventId = "event-3" }
+ };
+
+ // Act
+ var request = exporter.CreateTelemetryRequest(logs);
+
+ // Assert
+ Assert.Equal(3, request.ProtoLogs.Count);
+ }
+
+ [Fact]
+ public void DatabricksTelemetryExporter_CreateTelemetryRequest_UploadTime_IsRecentTimestamp()
+ {
+ // Arrange
+ using var httpClient = new HttpClient();
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ var logs = new List<TelemetryFrontendLog>
+ {
+ new TelemetryFrontendLog { WorkspaceId = 1 }
+ };
+
+ var beforeTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
+
+ // Act
+ var request = exporter.CreateTelemetryRequest(logs);
+
+ var afterTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
+
+ // Assert
+ Assert.True(request.UploadTime >= beforeTime);
+ Assert.True(request.UploadTime <= afterTime);
+ }
+
+ #endregion
+
+ #region Serialization Tests
+
+ [Fact]
+ public void DatabricksTelemetryExporter_SerializeRequest_ProducesValidJson()
+ {
+ // Arrange
+ using var httpClient = new HttpClient();
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ var request = new TelemetryRequest
+ {
+ UploadTime = 1234567890000,
+ ProtoLogs = new List<string> { "{\"workspace_id\":12345}"
+ };
+
+ // Act
+ var json = exporter.SerializeRequest(request);
+
+ // Assert
+ Assert.NotEmpty(json);
+ var parsed = JsonDocument.Parse(json);
+ Assert.Equal(1234567890000, parsed.RootElement.GetProperty("uploadTime").GetInt64());
+ Assert.Single(parsed.RootElement.GetProperty("protoLogs").EnumerateArray());
+ }
+
+ #endregion
+
+ #region ExportAsync Tests with Mock Handler
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_Success_ReturnsTrue()
+ {
+ // Arrange
+ var handler = new MockHttpMessageHandler((request, ct) =>
+ {
+ return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
+ });
+
+ using var httpClient = new HttpClient(handler);
+ var config = new TelemetryConfiguration { MaxRetries = 0 };
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ var logs = new List<TelemetryFrontendLog>
+ {
+ new TelemetryFrontendLog { WorkspaceId = 12345 }
+ };
+
+ // Act
+ var result = await exporter.ExportAsync(logs);
+
+ // Assert
+ Assert.True(result, "ExportAsync should return true on successful HTTP 200 response");
+ }
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_EmptyList_ReturnsTrueWithoutRequest()
+ {
+ // Arrange
+ var requestCount = 0;
+ var handler = new MockHttpMessageHandler((request, ct) =>
+ {
+ requestCount++;
+ return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
+ });
+
+ using var httpClient = new HttpClient(handler);
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ // Act
+ var result = await exporter.ExportAsync(new List<TelemetryFrontendLog>());
+
+ // Assert - no HTTP request should be made and should return true
+ Assert.Equal(0, requestCount);
+ Assert.True(result, "ExportAsync should return true for empty list (nothing to export)");
+ }
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_NullList_ReturnsTrueWithoutRequest()
+ {
+ // Arrange
+ var requestCount = 0;
+ var handler
= new MockHttpMessageHandler((request, ct) => + { + requestCount++; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + // Act + var result = await exporter.ExportAsync(null!); + + // Assert - no HTTP request should be made and should return true + Assert.Equal(0, requestCount); + Assert.True(result, "ExportAsync should return true for null list (nothing to export)"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_TransientFailure_RetriesAndReturnsTrue() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + if (attemptCount < 3) + { + throw new HttpRequestException("Simulated transient failure"); + } + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should retry and eventually succeed + Assert.Equal(3, attemptCount); + Assert.True(result, "ExportAsync should return true after successful retry"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_MaxRetries_ReturnsFalse() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + throw new HttpRequestException("Simulated persistent failure"); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new 
DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should have tried initial attempt + max retries and return false + Assert.Equal(4, attemptCount); // 1 initial + 3 retries + Assert.False(result, "ExportAsync should return false after all retries exhausted"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_TerminalError_ReturnsFalseWithoutRetry() + { + // Arrange + var attemptCount = 0; + var handler = new MockHttpMessageHandler((request, ct) => + { + attemptCount++; + // Throw HttpRequestException with an inner UnauthorizedAccessException + // The ExceptionClassifier checks inner exceptions for terminal types + throw new HttpRequestException("Authentication failed", + new UnauthorizedAccessException("Unauthorized")); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 3, RetryDelayMs = 10 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should not retry on terminal error and return false + Assert.Equal(1, attemptCount); + Assert.False(result, "ExportAsync should return false on terminal error"); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_Cancelled_ThrowsCancelledException() + { + // Arrange + var handler = new MockHttpMessageHandler((request, ct) => + { + ct.ThrowIfCancellationRequested(); + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new 
TelemetryFrontendLog { WorkspaceId = 12345 }
+ };
+
+ using var cts = new CancellationTokenSource();
+ cts.Cancel();
+
+ // Act & Assert - TaskCanceledException inherits from OperationCanceledException
+ var ex = await Assert.ThrowsAnyAsync<OperationCanceledException>(
+ () => exporter.ExportAsync(logs, cts.Token));
+ Assert.True(ex is OperationCanceledException);
+ }
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_Authenticated_SendsToCorrectEndpoint()
+ {
+ // Arrange
+ string? capturedUrl = null;
+ var handler = new MockHttpMessageHandler((request, ct) =>
+ {
+ capturedUrl = request.RequestUri?.ToString();
+ return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
+ });
+
+ using var httpClient = new HttpClient(handler);
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config);
+
+ var logs = new List<TelemetryFrontendLog>
+ {
+ new TelemetryFrontendLog { WorkspaceId = 12345 }
+ };
+
+ // Act
+ await exporter.ExportAsync(logs);
+
+ // Assert
+ Assert.Equal($"{TestHost}/telemetry-ext", capturedUrl);
+ }
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_Unauthenticated_SendsToCorrectEndpoint()
+ {
+ // Arrange
+ string? capturedUrl = null;
+ var handler = new MockHttpMessageHandler((request, ct) =>
+ {
+ capturedUrl = request.RequestUri?.ToString();
+ return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK));
+ });
+
+ using var httpClient = new HttpClient(handler);
+ var config = new TelemetryConfiguration();
+ var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, false, config);
+
+ var logs = new List<TelemetryFrontendLog>
+ {
+ new TelemetryFrontendLog { WorkspaceId = 12345 }
+ };
+
+ // Act
+ await exporter.ExportAsync(logs);
+
+ // Assert
+ Assert.Equal($"{TestHost}/telemetry-unauth", capturedUrl);
+ }
+
+ [Fact]
+ public async Task DatabricksTelemetryExporter_ExportAsync_SendsValidJsonBody()
+ {
+ // Arrange
+ string?
capturedContent = null; + var handler = new MockHttpMessageHandler(async (request, ct) => + { + capturedContent = await request.Content!.ReadAsStringAsync(); + return new HttpResponseMessage(HttpStatusCode.OK); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration(); + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog + { + WorkspaceId = 12345, + FrontendLogEventId = "test-event-123" + } + }; + + // Act + await exporter.ExportAsync(logs); + + // Assert + Assert.NotNull(capturedContent); + var parsedRequest = JsonDocument.Parse(capturedContent); + Assert.True(parsedRequest.RootElement.TryGetProperty("uploadTime", out _)); + Assert.True(parsedRequest.RootElement.TryGetProperty("protoLogs", out var protoLogs)); + Assert.Single(protoLogs.EnumerateArray()); + + var protoLogJson = protoLogs[0].GetString(); + Assert.Contains("12345", protoLogJson); + Assert.Contains("test-event-123", protoLogJson); + } + + [Fact] + public async Task DatabricksTelemetryExporter_ExportAsync_GenericException_ReturnsFalse() + { + // Arrange + var handler = new MockHttpMessageHandler((request, ct) => + { + throw new InvalidOperationException("Unexpected error"); + }); + + using var httpClient = new HttpClient(handler); + var config = new TelemetryConfiguration { MaxRetries = 0 }; + var exporter = new DatabricksTelemetryExporter(httpClient, TestHost, true, config); + + var logs = new List + { + new TelemetryFrontendLog { WorkspaceId = 12345 } + }; + + // Act + var result = await exporter.ExportAsync(logs); + + // Assert - should not throw but return false + Assert.False(result, "ExportAsync should return false on unexpected exception"); + } + + #endregion + + #region Mock HTTP Handler + + /// + /// Mock HttpMessageHandler for testing HTTP requests. 
+ /// </summary>
+ private class MockHttpMessageHandler : HttpMessageHandler
+ {
+ private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> _handler;
+
+ public MockHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
+ {
+ _handler = handler;
+ }
+
+ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
+ {
+ return _handler(request, cancellationToken);
+ }
+ }
+
+ #endregion
+ }
+}
diff --git a/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs b/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs
new file mode 100644
index 00000000..141fe84d
--- /dev/null
+++ b/csharp/test/Unit/Telemetry/MetricsAggregatorTests.cs
@@ -0,0 +1,812 @@
+/*
+* Copyright (c) 2025 ADBC Drivers Contributors
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using AdbcDrivers.Databricks.Telemetry;
+using AdbcDrivers.Databricks.Telemetry.Models;
+using Xunit;
+
+namespace AdbcDrivers.Databricks.Tests.Unit.Telemetry
+{
+ /// <summary>
+ /// Tests for MetricsAggregator class.
+ /// </summary>
+ public class MetricsAggregatorTests : IDisposable
+ {
+ private readonly ActivitySource _activitySource;
+ private readonly MockTelemetryExporter _mockExporter;
+ private readonly TelemetryConfiguration _config;
+ private MetricsAggregator?
_aggregator; + + private const long TestWorkspaceId = 12345; + private const string TestUserAgent = "TestAgent/1.0"; + + public MetricsAggregatorTests() + { + _activitySource = new ActivitySource("TestSource"); + _mockExporter = new MockTelemetryExporter(); + _config = new TelemetryConfiguration + { + BatchSize = 10, + FlushIntervalMs = 60000 // Set high to control when flush happens + }; + } + + public void Dispose() + { + _aggregator?.Dispose(); + _activitySource.Dispose(); + } + + #region Constructor Tests + + [Fact] + public void MetricsAggregator_Constructor_NullExporter_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new MetricsAggregator(null!, _config, TestWorkspaceId)); + } + + [Fact] + public void MetricsAggregator_Constructor_NullConfig_ThrowsException() + { + // Act & Assert + Assert.Throws(() => + new MetricsAggregator(_mockExporter, null!, TestWorkspaceId)); + } + + [Fact] + public void MetricsAggregator_Constructor_ValidParameters_CreatesInstance() + { + // Act + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Assert + Assert.NotNull(_aggregator); + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + #endregion + + #region ProcessActivity Tests - Connection Events + + [Fact] + public void MetricsAggregator_ProcessActivity_ConnectionOpen_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session-123"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert + 
Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_ProcessActivity_ConnectionOpenAsync_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("OpenAsync"); + Assert.NotNull(activity); + activity.SetTag("session.id", "test-session-456"); + activity.SetTag("driver.version", "1.0.0"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert + Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region ProcessActivity Tests - Statement Events + + [Fact] + public void MetricsAggregator_ProcessActivity_Statement_AggregatesByStatementId() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-123"; + var sessionId = "session-456"; + + // First activity + using (var activity1 = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity1); + activity1.SetTag("statement.id", statementId); + activity1.SetTag("session.id", sessionId); + activity1.SetTag("result.chunk_count", "5"); + activity1.Stop(); + + _aggregator.ProcessActivity(activity1); + } + + // Second activity with same statement_id + using (var activity2 = _activitySource.StartActivity("Statement.FetchResults")) + { + Assert.NotNull(activity2); + activity2.SetTag("statement.id", statementId); + activity2.SetTag("session.id", sessionId); + 
activity2.SetTag("result.chunk_count", "3"); + activity2.Stop(); + + _aggregator.ProcessActivity(activity2); + } + + // Assert - should not emit until CompleteStatement + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_ProcessActivity_Statement_WithoutStatementId_EmitsImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Statement.Execute"); + Assert.NotNull(activity); + // No statement.id tag + activity.SetTag("session.id", "session-123"); + activity.Stop(); + + // Act + _aggregator.ProcessActivity(activity); + + // Assert - should emit immediately since no statement_id + Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region CompleteStatement Tests + + [Fact] + public void MetricsAggregator_CompleteStatement_EmitsAggregatedEvent() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-complete-123"; + var sessionId = "session-complete-456"; + + using (var activity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", sessionId); + activity.SetTag("result.format", "cloudfetch"); + activity.SetTag("result.chunk_count", "10"); + activity.Stop(); + + _aggregator.ProcessActivity(activity); + } 
+ + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + + // Act + _aggregator.CompleteStatement(statementId); + + // Assert + Assert.Equal(1, _aggregator.PendingEventCount); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_CompleteStatement_NullStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.CompleteStatement(null!); + _aggregator.CompleteStatement(string.Empty); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_CompleteStatement_UnknownStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.CompleteStatement("unknown-statement-id"); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + #endregion + + #region FlushAsync Tests + + [Fact] + public async Task MetricsAggregator_FlushAsync_BatchSizeReached_ExportsEvents() + { + // Arrange + var config = new TelemetryConfiguration { BatchSize = 2, FlushIntervalMs = 60000 }; + _aggregator = new MetricsAggregator(_mockExporter, config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + // Add events that should trigger flush + for (int i = 0; i < 3; i++) + { + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + } + + // Wait for any background flush to complete + await Task.Delay(100); + + // Assert - exporter should have been called + 
Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_TimeInterval_ExportsEvents() + { + // Arrange + var config = new TelemetryConfiguration { BatchSize = 100, FlushIntervalMs = 100 }; // 100ms interval + _aggregator = new MetricsAggregator(_mockExporter, config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-timer"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + // Act - wait for timer to trigger flush + await Task.Delay(250); + + // Assert - exporter should have been called by timer + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_EmptyQueue_NoExport() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(0, _mockExporter.ExportCallCount); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_MultipleEvents_ExportsAll() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + for (int i = 0; i < 5; i++) + { + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", $"session-{i}"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + } + + Assert.Equal(5, 
_aggregator.PendingEventCount); + + // Act + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _mockExporter.ExportCallCount); + Assert.Equal(5, _mockExporter.TotalExportedEvents); + } + + #endregion + + #region RecordException Tests + + [Fact] + public void MetricsAggregator_RecordException_Terminal_FlushesImmediately() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var terminalException = new HttpRequestException("401 (Unauthorized)"); + + // Act + _aggregator.RecordException("stmt-123", "session-456", terminalException); + + // Assert - terminal exception should be queued immediately + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_BuffersUntilComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-retryable-123"; + var sessionId = "session-retryable-456"; + + // Act + _aggregator.RecordException(statementId, sessionId, retryableException); + + // Assert - retryable exception should be buffered, not queued + Assert.Equal(0, _aggregator.PendingEventCount); + Assert.Equal(1, _aggregator.ActiveStatementCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_EmittedOnFailedComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-failed-123"; + var sessionId = "session-failed-456"; + + _aggregator.RecordException(statementId, sessionId, retryableException); + Assert.Equal(0, _aggregator.PendingEventCount); + + // Act - complete statement as failed + _aggregator.CompleteStatement(statementId, 
failed: true); + + // Assert - both statement event and error event should be emitted + Assert.Equal(2, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_Retryable_NotEmittedOnSuccessComplete() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var retryableException = new HttpRequestException("503 (Service Unavailable)"); + var statementId = "stmt-success-123"; + var sessionId = "session-success-456"; + + _aggregator.RecordException(statementId, sessionId, retryableException); + + // Act - complete statement as success + _aggregator.CompleteStatement(statementId, failed: false); + + // Assert - only statement event should be emitted, not the error + Assert.Equal(1, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_NullException_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.RecordException("stmt-123", "session-456", null); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + [Fact] + public void MetricsAggregator_RecordException_NullStatementId_NoOp() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act - should not throw + _aggregator.RecordException(null!, "session-456", new Exception("test")); + _aggregator.RecordException(string.Empty, "session-456", new Exception("test")); + + // Assert + Assert.Equal(0, _aggregator.PendingEventCount); + } + + #endregion + + #region Exception Swallowing Tests + + [Fact] + public void MetricsAggregator_ProcessActivity_ExceptionSwallowed_NoThrow() + { + // Arrange + var throwingExporter = new ThrowingTelemetryExporter(); + _aggregator = new MetricsAggregator(throwingExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act & Assert - should not throw even with throwing 
exporter + _aggregator.ProcessActivity(null); + } + + [Fact] + public async Task MetricsAggregator_FlushAsync_ExceptionSwallowed_NoThrow() + { + // Arrange + var throwingExporter = new ThrowingTelemetryExporter(); + _aggregator = new MetricsAggregator(throwingExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-throw"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + // Act & Assert - should not throw + await _aggregator.FlushAsync(); + } + + #endregion + + #region Tag Filtering Tests + + [Fact] + public void MetricsAggregator_ProcessActivity_FiltersTags_UsingRegistry() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-filter-123"; + + using (var activity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", "session-filter"); + activity.SetTag("result.format", "cloudfetch"); + // This sensitive tag should be filtered out + activity.SetTag("db.statement", "SELECT * FROM sensitive_table"); + // This tag should be included + activity.SetTag("result.chunk_count", "5"); + activity.Stop(); + + _aggregator.ProcessActivity(activity); + } + + // Complete to emit + _aggregator.CompleteStatement(statementId); + + // Assert - event should be created (we can verify via export) + 
Assert.Equal(1, _aggregator.PendingEventCount); + } + + #endregion + + #region WrapInFrontendLog Tests + + [Fact] + public void MetricsAggregator_WrapInFrontendLog_CreatesValidStructure() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var telemetryEvent = new TelemetryEvent + { + SessionId = "session-wrap-123", + SqlStatementId = "stmt-wrap-456", + OperationLatencyMs = 100 + }; + + // Act + var frontendLog = _aggregator.WrapInFrontendLog(telemetryEvent); + + // Assert + Assert.NotNull(frontendLog); + Assert.Equal(TestWorkspaceId, frontendLog.WorkspaceId); + Assert.NotEmpty(frontendLog.FrontendLogEventId); + Assert.NotNull(frontendLog.Context); + Assert.NotNull(frontendLog.Context.ClientContext); + Assert.Equal(TestUserAgent, frontendLog.Context.ClientContext.UserAgent); + Assert.True(frontendLog.Context.TimestampMillis > 0); + Assert.NotNull(frontendLog.Entry); + Assert.NotNull(frontendLog.Entry.SqlDriverLog); + Assert.Equal(telemetryEvent.SessionId, frontendLog.Entry.SqlDriverLog.SessionId); + Assert.Equal(telemetryEvent.SqlStatementId, frontendLog.Entry.SqlDriverLog.SqlStatementId); + } + + [Fact] + public void MetricsAggregator_WrapInFrontendLog_GeneratesUniqueEventIds() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + var telemetryEvent = new TelemetryEvent { SessionId = "session-unique" }; + + // Act + var frontendLog1 = _aggregator.WrapInFrontendLog(telemetryEvent); + var frontendLog2 = _aggregator.WrapInFrontendLog(telemetryEvent); + + // Assert + Assert.NotEqual(frontendLog1.FrontendLogEventId, frontendLog2.FrontendLogEventId); + } + + #endregion + + #region Dispose Tests + + [Fact] + public void MetricsAggregator_Dispose_FlushesRemainingEvents() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = 
_ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + using var activity = _activitySource.StartActivity("Connection.Open"); + Assert.NotNull(activity); + activity.SetTag("session.id", "session-dispose"); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + Assert.Equal(1, _aggregator.PendingEventCount); + + // Act + _aggregator.Dispose(); + + // Assert - events should have been flushed + Assert.True(_mockExporter.ExportCallCount > 0); + } + + [Fact] + public void MetricsAggregator_Dispose_CanBeCalledMultipleTimes() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + // Act & Assert - should not throw + _aggregator.Dispose(); + _aggregator.Dispose(); + } + + #endregion + + #region Integration Tests + + [Fact] + public async Task MetricsAggregator_EndToEnd_StatementLifecycle() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var statementId = "stmt-e2e-123"; + var sessionId = "session-e2e-456"; + + // Simulate statement execution + using (var executeActivity = _activitySource.StartActivity("Statement.Execute")) + { + Assert.NotNull(executeActivity); + executeActivity.SetTag("statement.id", statementId); + executeActivity.SetTag("session.id", sessionId); + executeActivity.SetTag("result.format", "cloudfetch"); + executeActivity.Stop(); + _aggregator.ProcessActivity(executeActivity); + } + + // Simulate chunk downloads + for (int i = 0; i < 3; i++) + { + using var downloadActivity = _activitySource.StartActivity("CloudFetch.Download"); + Assert.NotNull(downloadActivity); + downloadActivity.SetTag("statement.id", 
statementId); + downloadActivity.SetTag("session.id", sessionId); + downloadActivity.SetTag("result.chunk_count", "1"); + downloadActivity.SetTag("result.bytes_downloaded", "1000"); + downloadActivity.Stop(); + _aggregator.ProcessActivity(downloadActivity); + } + + // Complete statement + _aggregator.CompleteStatement(statementId); + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(1, _mockExporter.ExportCallCount); + Assert.Equal(1, _mockExporter.TotalExportedEvents); + Assert.Equal(0, _aggregator.ActiveStatementCount); + } + + [Fact] + public async Task MetricsAggregator_EndToEnd_MultipleStatements() + { + // Arrange + _aggregator = new MetricsAggregator(_mockExporter, _config, TestWorkspaceId, TestUserAgent); + + using var listener = new ActivityListener + { + ShouldListenTo = _ => true, + Sample = (ref ActivityCreationOptions _) => ActivitySamplingResult.AllDataAndRecorded + }; + ActivitySource.AddActivityListener(listener); + + var sessionId = "session-multi"; + + // Execute 3 statements + for (int i = 0; i < 3; i++) + { + var statementId = $"stmt-multi-{i}"; + + using var activity = _activitySource.StartActivity("Statement.Execute"); + Assert.NotNull(activity); + activity.SetTag("statement.id", statementId); + activity.SetTag("session.id", sessionId); + activity.Stop(); + _aggregator.ProcessActivity(activity); + + _aggregator.CompleteStatement(statementId); + } + + // Flush + await _aggregator.FlushAsync(); + + // Assert + Assert.Equal(3, _mockExporter.TotalExportedEvents); + } + + #endregion + + #region Mock Classes + + /// + /// Mock telemetry exporter for testing. 
+ /// + private class MockTelemetryExporter : ITelemetryExporter + { + private int _exportCallCount; + private int _totalExportedEvents; + private readonly ConcurrentBag _exportedLogs = new ConcurrentBag(); + + public int ExportCallCount => _exportCallCount; + public int TotalExportedEvents => _totalExportedEvents; + public IReadOnlyCollection ExportedLogs => _exportedLogs.ToList(); + + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + Interlocked.Increment(ref _exportCallCount); + Interlocked.Add(ref _totalExportedEvents, logs.Count); + + foreach (var log in logs) + { + _exportedLogs.Add(log); + } + + return Task.CompletedTask; + } + } + + /// + /// Telemetry exporter that always throws for testing exception handling. + /// + private class ThrowingTelemetryExporter : ITelemetryExporter + { + public Task ExportAsync(IReadOnlyList logs, CancellationToken ct = default) + { + throw new InvalidOperationException("Test exception from exporter"); + } + } + + #endregion + } +}