From b09b5b93a4249cc5886353c7d811a80b2fccce9f Mon Sep 17 00:00:00 2001 From: telli Date: Mon, 6 Apr 2026 01:09:22 -0700 Subject: [PATCH 1/2] add model profiles and gemma provider routing --- README.md | 110 ++++ docs/MODEL_PROFILES.md | 267 +++++++++ src/OpenClaw.Agent/AgentRuntime.cs | 13 + src/OpenClaw.Agent/ILlmExecutionService.cs | 4 + src/OpenClaw.Cli/OpenClawHttpClient.cs | 9 + src/OpenClaw.Cli/Program.cs | 127 +++++ src/OpenClaw.Client/OpenClawHttpClient.cs | 22 + .../Abstractions/IModelProfiles.cs | 43 ++ src/OpenClaw.Core/Models/GatewayConfig.cs | 5 + .../Models/IntegrationApiModels.cs | 1 + src/OpenClaw.Core/Models/ModelProfiles.cs | 147 +++++ .../Models/ModelSelectionException.cs | 9 + src/OpenClaw.Core/Models/OperatorApiModels.cs | 4 + src/OpenClaw.Core/Models/Session.cs | 30 + .../Validation/ConfigValidator.cs | 50 ++ src/OpenClaw.Core/Validation/DoctorCheck.cs | 16 + .../Composition/CoreServicesExtensions.cs | 5 + .../Composition/IntegrationApiFacade.cs | 5 + .../RuntimeInitializationExtensions.cs | 4 + .../Endpoints/AdminEndpoints.cs | 51 ++ .../Endpoints/OpenAiEndpoints.cs | 26 +- .../Extensions/GatewayWorkers.cs | 18 + .../GatewayLlmExecutionService.cs | 380 ++++++++----- .../Models/ConfiguredModelProfileRegistry.cs | 238 ++++++++ .../Models/DefaultModelSelectionPolicy.cs | 265 +++++++++ .../Models/ModelEvaluationRunner.cs | 525 ++++++++++++++++++ .../RuntimeOperationsState.cs | 16 + .../MafAgentRuntime.cs | 19 + .../ModelProfileSelectionTests.cs | 293 ++++++++++ 29 files changed, 2554 insertions(+), 148 deletions(-) create mode 100644 docs/MODEL_PROFILES.md create mode 100644 src/OpenClaw.Core/Abstractions/IModelProfiles.cs create mode 100644 src/OpenClaw.Core/Models/ModelProfiles.cs create mode 100644 src/OpenClaw.Core/Models/ModelSelectionException.cs create mode 100644 src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs create mode 100644 src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs create mode 100644 src/OpenClaw.Gateway/Models/ModelEvaluationRunner.cs create mode 100644 src/OpenClaw.Tests/ModelProfileSelectionTests.cs diff --git a/README.md b/README.md index 2420272..b3469c1 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ If this repo is useful to you, please star it. - Configurable reasoning effort (`/think off|low|medium|high`) - Delegated sub-agents with configurable profiles, tool restrictions, and depth limits - Multi-agent routing — route channels/senders with per-route model, prompt, tool preset, and tool allowlist overrides +- Profile-aware routing — route channels/senders with per-route profile id, capability requirements, preferred tags, and fallback profile order - Persistent session search, user profiles, and session-scoped todo state available to the agent and operators ### Built-In Providers @@ -52,6 +53,17 @@ If this repo is useful to you, please star it. OpenClaw registers OpenAI, Claude, and Gemini natively at startup, so a fresh install only needs a provider id, model, and API key to get going. +### Model Profiles and Gemma + +OpenClaw now supports **provider-agnostic named model profiles**. This keeps model routing and capability checks above the provider layer, so **Gemma-family models, including Gemma 4, plug into the existing runtime through Ollama or OpenAI-compatible endpoints** instead of requiring a Gemma-specific execution path. + +- Configure **Gemma 4 local/dev** with `Provider: "ollama"` and a named profile such as `gemma4-local` +- Configure **Gemma 4 production/self-hosted** with `Provider: "openai-compatible"` and a named profile such as `gemma4-prod` +- Select profiles explicitly or by **capabilities** and **tags** such as `local`, `private`, `cheap`, `tool-reliable`, or `vision` +- The runtime will **fail clearly** or **fall back** if a selected profile cannot satisfy required capabilities like tool calling, structured outputs, streaming, or image input + +See [Model Profiles and Gemma](docs/MODEL_PROFILES.md) for the full configuration and evaluation guide. + ### Review-First Learning - The runtime can observe completed sessions and create **pending learning proposals** instead of auto-mutating behavior @@ -261,6 +273,12 @@ Then open one of: # CLI chat dotnet run --project src/OpenClaw.Cli -c Release -- chat +# Inspect registered model profiles +dotnet run --project src/OpenClaw.Cli -c Release -- models list + +# Run a built-in evaluation suite against a profile +dotnet run --project src/OpenClaw.Cli -c Release -- eval run --profile gemma4-prod + # CLI live session dotnet run --project src/OpenClaw.Cli -c Release -- live --provider gemini @@ -305,6 +323,98 @@ export MODEL_PROVIDER_KEY="sk-ant-..." export OpenClaw__Llm__Provider="gemini" export OpenClaw__Llm__Model="gemini-2.5-flash" export MODEL_PROVIDER_KEY="AIza..." + +# Ollama +export OpenClaw__Llm__Provider="ollama" +export OpenClaw__Llm__Model="gemma4" + +# OpenAI-compatible +export OpenClaw__Llm__Provider="openai-compatible" +export OpenClaw__Llm__Model="gemma-4" +export OpenClaw__Llm__Endpoint="https://your-inference-gateway.example.com/v1" +export MODEL_PROVIDER_KEY="your-api-key" +``` + +**Example model profile config with Gemma 4:** + +```json +{ + "OpenClaw": { + "Llm": { + "Provider": "openai", + "Model": "gpt-4.1" + }, + "Models": { + "DefaultProfile": "gemma4-prod", + "Profiles": [ + { + "Id": "gemma4-local", + "Provider": "ollama", + "Model": "gemma4", + "BaseUrl": "http://localhost:11434/v1", + "Tags": ["local", "private", "cheap"], + "Capabilities": { + "SupportsTools": false, + "SupportsVision": true, + "SupportsJsonSchema": false, + "SupportsStructuredOutputs": false, + "SupportsStreaming": true, + "SupportsParallelToolCalls": false, + "SupportsReasoningEffort": false, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": false, + "MaxContextTokens": 131072, + "MaxOutputTokens": 8192 + } + }, + { + "Id": "gemma4-prod", + "Provider": "openai-compatible", + "Model": "gemma-4", + "BaseUrl": "https://your-inference-gateway.example.com/v1", + "ApiKey": "env:MODEL_PROVIDER_KEY", + "Tags": ["private", "prod", "vision"], + "FallbackProfileIds": ["frontier-tools"], + "Capabilities": { + "SupportsTools": true, + "SupportsVision": true, + "SupportsJsonSchema": true, + "SupportsStructuredOutputs": true, + "SupportsStreaming": true, + "SupportsParallelToolCalls": true, + "SupportsReasoningEffort": false, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": false, + "MaxContextTokens": 262144, + "MaxOutputTokens": 16384 + } + }, + { + "Id": "frontier-tools", + "Provider": "openai", + "Model": "gpt-4.1", + "Tags": ["tool-reliable", "frontier"], + "Capabilities": { + "SupportsTools": true, + "SupportsVision": true, + "SupportsJsonSchema": true, + "SupportsStructuredOutputs": true, + "SupportsStreaming": true, + "SupportsParallelToolCalls": true, + "SupportsReasoningEffort": true, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": true, + "MaxContextTokens": 1000000, + "MaxOutputTokens": 32768 + } + } + ] + } + } +} ``` See the full [Quickstart Guide](docs/QUICKSTART.md) for deployment notes. diff --git a/docs/MODEL_PROFILES.md b/docs/MODEL_PROFILES.md new file mode 100644 index 0000000..1e9198a --- /dev/null +++ b/docs/MODEL_PROFILES.md @@ -0,0 +1,267 @@ +# Model Profiles and Gemma + +OpenClaw integrates **Gemma-family models, including Gemma 4**, through the existing provider seams instead of creating a Gemma-specific runtime fork. + +That design keeps: + +- one execution stack +- one tool-calling stack +- one session/compaction/middleware stack +- one MAF integration path + +Gemma is treated as a **model backend** that can be reached through: + +1. **Ollama** for local and development workflows +2. **OpenAI-compatible endpoints** for production or self-hosted inference gateways +3. future provider extensions if needed, without changing the runtime architecture + +## Why profiles exist + +Providers and models do not expose the same capabilities. A route that needs tool calling, structured outputs, and image input should not silently run against a model that only supports plain text chat. + +Model profiles let OpenClaw describe a model instance independently from the provider transport: + +- profile id +- provider id +- model id +- base URL +- API key or env ref +- capabilities +- context/output hints +- tags such as `local`, `private`, `cheap`, `tool-reliable`, `vision` + +The runtime uses those profiles to: + +- select a profile explicitly +- choose a profile based on route/session capability requirements +- prefer tags such as `local` or `private` +- fall back to another profile when allowed +- fail clearly when no profile can safely satisfy the request + +## Example configuration + +```json +{ + "OpenClaw": { + "Llm": { + "Provider": "openai", + "Model": "gpt-4.1" + }, + "Models": { + "DefaultProfile": "gemma4-prod", + "Profiles": [ + { + "Id": "gemma4-local", + "Provider": "ollama", + "Model": "gemma4", + "BaseUrl": "http://localhost:11434/v1", + "Tags": ["local", "private", "cheap"], + "Capabilities": { + "SupportsTools": false, + "SupportsVision": true, + "SupportsJsonSchema": false, + "SupportsStructuredOutputs": false, + "SupportsStreaming": true, + "SupportsParallelToolCalls": false, + "SupportsReasoningEffort": false, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": false, + "MaxContextTokens": 131072, + "MaxOutputTokens": 8192 + } + }, + { + "Id": "gemma4-prod", + "Provider": "openai-compatible", + "Model": "gemma-4", + "BaseUrl": "https://your-inference-gateway.example.com/v1", + "ApiKey": "env:MODEL_PROVIDER_KEY", + "Tags": ["private", "prod", "vision"], + "FallbackProfileIds": ["frontier-tools"], + "Capabilities": { + "SupportsTools": true, + "SupportsVision": true, + "SupportsJsonSchema": true, + "SupportsStructuredOutputs": true, + "SupportsStreaming": true, + "SupportsParallelToolCalls": true, + "SupportsReasoningEffort": false, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": false, + "MaxContextTokens": 262144, + "MaxOutputTokens": 16384 + } + }, + { + "Id": "frontier-tools", + "Provider": "openai", + "Model": "gpt-4.1", + "Tags": ["tool-reliable", "frontier"], + "Capabilities": { + "SupportsTools": true, + "SupportsVision": true, + "SupportsJsonSchema": true, + "SupportsStructuredOutputs": true, + "SupportsStreaming": true, + "SupportsParallelToolCalls": true, + "SupportsReasoningEffort": true, + "SupportsSystemMessages": true, + "SupportsImageInput": true, + "SupportsAudioInput": true, + "MaxContextTokens": 1000000, + "MaxOutputTokens": 32768 + } + } + ] + }, + "Routing": { + "Enabled": true, + "Routes": { + "telegram:private-coder": { + "ChannelId": "telegram", + "SenderId": "private-coder", + "ModelProfileId": "gemma4-local", + "PreferredModelTags": ["local", "private"], + "FallbackModelProfileIds": ["frontier-tools"], + "ModelRequirements": { + "SupportsTools": true, + "SupportsStreaming": true + } + } + } + } + } +} +``` + +## Gemma through Ollama + +Use this when you want local/private inference for development or workstation deployments. + +```json +{ + "Id": "gemma4-local", + "Provider": "ollama", + "Model": "gemma4", + "BaseUrl": "http://localhost:11434/v1", + "Tags": ["local", "private", "cheap"] +} +``` + +Notes: + +- OpenClaw talks to Ollama through the existing OpenAI-compatible adapter path. +- `BaseUrl` defaults to `http://localhost:11434/v1` if omitted by the legacy provider config, but setting it explicitly is clearer for named profiles. +- If the profile does not advertise `SupportsTools`, routes that require tools will fail clearly or fall back. + +## Gemma through an OpenAI-compatible gateway + +Use this when Gemma is hosted behind a production inference service that exposes an OpenAI-compatible API. + +```json +{ + "Id": "gemma4-prod", + "Provider": "openai-compatible", + "Model": "gemma-4", + "BaseUrl": "https://your-inference-gateway.example.com/v1", + "ApiKey": "env:MODEL_PROVIDER_KEY", + "Tags": ["private", "prod", "vision"] +} +``` + +Notes: + +- OpenClaw uses the existing OpenAI-compatible provider transport. +- No Gemma-specific runtime logic is required. +- Capability flags should reflect what your actual gateway exposes for that Gemma deployment. + +## Route assignment and fallback + +Routes can now express: + +- `ModelProfileId` +- `PreferredModelTags` +- `FallbackModelProfileIds` +- `ModelRequirements` + +Common patterns: + +- coding/tool-heavy route: require `SupportsTools=true`, prefer tag `tool-reliable` +- privacy-sensitive route: prefer tags `local` and `private` +- cheap summarization route: prefer tags `cheap` and `local` + +If the selected profile cannot satisfy the request, OpenClaw will either: + +- fall back to the first compatible profile in `FallbackModelProfileIds`, or +- fail with a clear message such as: + +`This route requires tool calling, but selected model profile 'gemma4-local' does not support it.` + +## Capability flags + +OpenClaw currently uses capability flags for: + +- tool calling +- vision and image input +- JSON schema and structured outputs +- streaming +- parallel tool calls +- reasoning effort +- system messages +- audio input +- context/output token hints + +These flags drive profile selection and request validation. They do not add provider-specific runtime branches. + +## CLI and operator surfaces + +List profiles: + +```bash +openclaw models list +``` + +Run profile doctor: + +```bash +openclaw models doctor +``` + +Run the built-in evaluation suite: + +```bash +openclaw eval run --profile gemma4-prod +``` + +Compare multiple profiles: + +```bash +openclaw eval compare --profiles gemma4-prod,frontier-tools +``` + +The gateway also exposes: + +- `GET /admin/models` +- `GET /admin/models/doctor` +- `POST /admin/models/evaluations` + +## Evaluation harness + +The first version ships with OpenClaw-native scenarios: + +- plain chat response +- structured JSON extraction +- tool selection correctness +- multi-turn continuity +- compaction recovery +- streaming behavior +- vision input behavior + +Reports are written to: + +- `memory/admin/model-evaluations/.json` +- `memory/admin/model-evaluations/.md` + +This is intentionally lightweight and filesystem-based for the first release. diff --git a/src/OpenClaw.Agent/AgentRuntime.cs b/src/OpenClaw.Agent/AgentRuntime.cs index 4efa2a8..16e6f47 100644 --- a/src/OpenClaw.Agent/AgentRuntime.cs +++ b/src/OpenClaw.Agent/AgentRuntime.cs @@ -299,6 +299,12 @@ public async Task RunAsync( LogTurnComplete(turnCtx); return ex.Message; } + catch (ModelSelectionException ex) + { + _logger?.LogWarning("[{CorrelationId}] Model selection failed: {Message}", turnCtx.CorrelationId, ex.Message); + LogTurnComplete(turnCtx); + return ex.Message; + } catch (Exception ex) { @@ -797,6 +803,13 @@ private async Task StreamLlmCollectAsync( { throw; } + catch (ModelSelectionException ex) + { + _logger?.LogWarning("[{CorrelationId}] Streaming model selection failed: {Message}", turnCtx.CorrelationId, ex.Message); + result.Error = ex.Message; + LogTurnComplete(turnCtx); + return result; + } catch (Exception ex) { _metrics?.IncrementLlmErrors(); diff --git a/src/OpenClaw.Agent/ILlmExecutionService.cs b/src/OpenClaw.Agent/ILlmExecutionService.cs index f98606e..540ac50 100644 --- a/src/OpenClaw.Agent/ILlmExecutionService.cs +++ b/src/OpenClaw.Agent/ILlmExecutionService.cs @@ -12,17 +12,21 @@ public sealed class LlmExecutionEstimate public sealed class LlmExecutionResult { + public string? ProfileId { get; init; } public required string ProviderId { get; init; } public required string ModelId { get; init; } public string? PolicyRuleId { get; init; } + public string? SelectionExplanation { get; init; } public required ChatResponse Response { get; init; } } public sealed class LlmStreamingExecutionResult { + public string? ProfileId { get; init; } public required string ProviderId { get; init; } public required string ModelId { get; init; } public string? PolicyRuleId { get; init; } + public string? SelectionExplanation { get; init; } public required IAsyncEnumerable Updates { get; init; } } diff --git a/src/OpenClaw.Cli/OpenClawHttpClient.cs b/src/OpenClaw.Cli/OpenClawHttpClient.cs index ec0776f..bfc796f 100644 --- a/src/OpenClaw.Cli/OpenClawHttpClient.cs +++ b/src/OpenClaw.Cli/OpenClawHttpClient.cs @@ -37,6 +37,15 @@ public Task GetHeartbeatStatusAsync(CancellationToken c public Task GetSecurityPostureAsync(CancellationToken cancellationToken) => _inner.GetSecurityPostureAsync(cancellationToken); + public Task GetModelProfilesAsync(CancellationToken cancellationToken) + => _inner.GetModelProfilesAsync(cancellationToken); + + public Task GetModelSelectionDoctorAsync(CancellationToken cancellationToken) + => _inner.GetModelSelectionDoctorAsync(cancellationToken); + + public Task RunModelEvaluationAsync(ModelEvaluationRequest request, CancellationToken cancellationToken) + => _inner.RunModelEvaluationAsync(request, cancellationToken); + public Task SimulateApprovalAsync(ApprovalSimulationRequest request, CancellationToken cancellationToken) => _inner.SimulateApprovalAsync(request, cancellationToken); diff --git a/src/OpenClaw.Cli/Program.cs b/src/OpenClaw.Cli/Program.cs index bad66b5..e57cd00 100644 --- a/src/OpenClaw.Cli/Program.cs +++ b/src/OpenClaw.Cli/Program.cs @@ -34,6 +34,8 @@ public static async Task Main(string[] args) "setup" => await SetupAsync(rest), "migrate" => await MigrateAsync(rest), "heartbeat" => await HeartbeatAsync(rest), + "models" => await ModelsAsync(rest), + "eval" => await EvalAsync(rest), "admin" => await AdminAsync(rest), "plugins" => await PluginCommands.RunAsync(rest), "clawhub" => await ClawHubCommand.RunAsync(rest), @@ -80,6 +82,8 @@ openclaw tui [options] openclaw setup [options] openclaw migrate [options] openclaw heartbeat [options] + openclaw models [options] + openclaw eval [options] openclaw admin [options] openclaw clawhub [wrapper options] [--] @@ -111,6 +115,10 @@ openclaw tui openclaw setup --workspace ./workspace openclaw migrate --apply openclaw heartbeat status + openclaw models list + openclaw models doctor + openclaw eval run --profile gemma4-prod + openclaw eval compare --profiles gemma4-prod,frontier-tools openclaw heartbeat wizard openclaw admin posture openclaw admin approvals simulate --tool shell --args "{\"command\":\"pwd\"}" @@ -163,6 +171,30 @@ openclaw admin approvals simulate --tool [--args ] [--autonomy ] [--token ] + openclaw models doctor [--url ] [--token ] + """); + } + + private static void PrintEvalHelp() + { + Console.WriteLine( + """ + openclaw eval + + Usage: + openclaw eval run [--profile ] [--scenario ]... [--url ] [--token ] + openclaw eval compare --profiles [--scenario ]... [--url ] [--token ] + """); + } + private static void PrintSetupHelp() { Console.WriteLine( @@ -634,6 +666,101 @@ private static async Task AdminAsync(string[] args) return 2; } + private static async Task ModelsAsync(string[] args) + { + if (args.Length == 0 || args[0] is "-h" or "--help" or "help") + { + PrintModelsHelp(); + return 0; + } + + var subcommand = args[0].Trim().ToLowerInvariant(); + var parsed = CliArgs.Parse(args.Skip(1).ToArray()); + var baseUrl = parsed.GetOption("--url") ?? Environment.GetEnvironmentVariable(EnvBaseUrl) ?? DefaultBaseUrl; + var token = ResolveAuthToken(parsed, Console.Error); + using var client = new OpenClawHttpClient(baseUrl, token); + + if (subcommand == "list") + { + var response = await client.GetModelProfilesAsync(CancellationToken.None); + Console.WriteLine($"default_profile={response.DefaultProfileId ?? "none"}"); + foreach (var profile in response.Profiles) + { + Console.WriteLine($"- {profile.Id} | {profile.ProviderId}/{profile.ModelId} | default={profile.IsDefault.ToString().ToLowerInvariant()} | tags={string.Join(",", profile.Tags)}"); + if (profile.ValidationIssues.Length > 0) + Console.WriteLine($" issues: {string.Join("; ", profile.ValidationIssues)}"); + } + + return 0; + } + + if (subcommand == "doctor") + { + var response = await client.GetModelSelectionDoctorAsync(CancellationToken.None); + Console.WriteLine($"default_profile={response.DefaultProfileId ?? "none"}"); + foreach (var error in response.Errors) + Console.WriteLine($"ERROR: {error}"); + foreach (var warning in response.Warnings) + Console.WriteLine($"WARN: {warning}"); + foreach (var profile in response.Profiles) + Console.WriteLine($"- {profile.Id} | available={profile.IsAvailable.ToString().ToLowerInvariant()} | {profile.ProviderId}/{profile.ModelId}"); + return response.Errors.Count > 0 ? 1 : 0; + } + + PrintModelsHelp(); + return 2; + } + + private static async Task EvalAsync(string[] args) + { + if (args.Length == 0 || args[0] is "-h" or "--help" or "help") + { + PrintEvalHelp(); + return 0; + } + + var subcommand = args[0].Trim().ToLowerInvariant(); + var parsed = CliArgs.Parse(args.Skip(1).ToArray()); + var baseUrl = parsed.GetOption("--url") ?? Environment.GetEnvironmentVariable(EnvBaseUrl) ?? DefaultBaseUrl; + var token = ResolveAuthToken(parsed, Console.Error); + using var client = new OpenClawHttpClient(baseUrl, token); + + if (subcommand is "run" or "compare") + { + var profiles = new List(); + if (parsed.GetOption("--profile") is { Length: > 0 } singleProfile) + profiles.Add(singleProfile); + if (parsed.GetOption("--profiles") is { Length: > 0 } multiProfiles) + profiles.AddRange(multiProfiles.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)); + + var scenarios = parsed.Options.TryGetValue("--scenario", out var scenarioValues) + ? scenarioValues.ToArray() + : []; + var report = await client.RunModelEvaluationAsync(new ModelEvaluationRequest + { + ProfileIds = profiles.Distinct(StringComparer.OrdinalIgnoreCase).ToArray(), + ScenarioIds = scenarios, + IncludeMarkdown = true + }, CancellationToken.None); + + Console.WriteLine($"run_id={report.RunId}"); + Console.WriteLine($"json={report.JsonPath}"); + if (!string.IsNullOrWhiteSpace(report.MarkdownPath)) + Console.WriteLine($"markdown={report.MarkdownPath}"); + foreach (var profile in report.Profiles) + { + Console.WriteLine($"[{profile.ProfileId}] {profile.ProviderId}/{profile.ModelId}"); + foreach (var scenario in profile.Scenarios) + Console.WriteLine($"- {scenario.ScenarioId}: {scenario.Status} ({scenario.LatencyMs} ms) {scenario.Summary ?? scenario.Error ?? ""}".TrimEnd()); + } + + return 0; + } + + PrintEvalHelp(); + return 2; + } + private static async Task HeartbeatStatusAsync(OpenClawHttpClient client) { var status = await client.GetHeartbeatStatusAsync(CancellationToken.None); diff --git a/src/OpenClaw.Client/OpenClawHttpClient.cs b/src/OpenClaw.Client/OpenClawHttpClient.cs index afed822..063099c 100644 --- a/src/OpenClaw.Client/OpenClawHttpClient.cs +++ b/src/OpenClaw.Client/OpenClawHttpClient.cs @@ -33,6 +33,9 @@ public sealed class OpenClawHttpClient : IDisposable private readonly Uri _adminHeartbeatPreviewUri; private readonly Uri _adminHeartbeatStatusUri; private readonly Uri _adminPostureUri; + private readonly Uri _adminModelsUri; + private readonly Uri _adminModelsDoctorUri; + private readonly Uri _adminModelEvaluationsUri; private readonly Uri _adminApprovalSimulationUri; private readonly Uri _adminIncidentExportUri; private readonly Uri _adminWhatsAppSetupUri; @@ -71,6 +74,9 @@ public OpenClawHttpClient(string baseUrl, string? authToken, HttpClient? httpCli _adminHeartbeatPreviewUri = new Uri(baseUri, "/admin/heartbeat/preview"); _adminHeartbeatStatusUri = new Uri(baseUri, "/admin/heartbeat/status"); _adminPostureUri = new Uri(baseUri, "/admin/posture"); + _adminModelsUri = new Uri(baseUri, "/admin/models"); + _adminModelsDoctorUri = new Uri(baseUri, "/admin/models/doctor"); + _adminModelEvaluationsUri = new Uri(baseUri, "/admin/models/evaluations"); _adminApprovalSimulationUri = new Uri(baseUri, "/admin/approvals/simulate"); _adminIncidentExportUri = new Uri(baseUri, "/admin/incident/export"); _adminWhatsAppSetupUri = new Uri(baseUri, "/admin/channels/whatsapp/setup"); @@ -438,6 +444,22 @@ public Task GetHeartbeatStatusAsync(CancellationToken c public Task GetSecurityPostureAsync(CancellationToken cancellationToken) => GetAsync(_adminPostureUri, CoreJsonContext.Default.SecurityPostureResponse, cancellationToken); + public Task GetModelProfilesAsync(CancellationToken cancellationToken) + => GetAsync(_adminModelsUri, CoreJsonContext.Default.ModelProfilesStatusResponse, cancellationToken); + + public Task GetModelSelectionDoctorAsync(CancellationToken cancellationToken) + => GetAsync(_adminModelsDoctorUri, CoreJsonContext.Default.ModelSelectionDoctorResponse, cancellationToken); + + public async Task RunModelEvaluationAsync(ModelEvaluationRequest request, CancellationToken cancellationToken) + { + using var httpRequest = new HttpRequestMessage(HttpMethod.Post, _adminModelEvaluationsUri) + { + Content = BuildJsonContent(request, CoreJsonContext.Default.ModelEvaluationRequest) + }; + + return await SendAsync(httpRequest, CoreJsonContext.Default.ModelEvaluationReport, cancellationToken); + } + public async Task SimulateApprovalAsync( ApprovalSimulationRequest request, CancellationToken cancellationToken) diff --git a/src/OpenClaw.Core/Abstractions/IModelProfiles.cs b/src/OpenClaw.Core/Abstractions/IModelProfiles.cs new file mode 100644 index 0000000..e3318db --- /dev/null +++ b/src/OpenClaw.Core/Abstractions/IModelProfiles.cs @@ -0,0 +1,43 @@ +using Microsoft.Extensions.AI; +using OpenClaw.Core.Models; + +namespace OpenClaw.Core.Abstractions; + +public interface IModelProfileRegistry +{ + string? DefaultProfileId { get; } + bool TryGet(string profileId, out ModelProfile? profile); + IReadOnlyList ListStatuses(); +} + +public sealed class ModelSelectionRequest +{ + public string? ExplicitProfileId { get; init; } + public required Session Session { get; init; } + public required IReadOnlyList Messages { get; init; } + public ChatOptions? Options { get; init; } + public bool Streaming { get; init; } +} + +public sealed class ModelSelectionCandidate +{ + public required ModelProfile Profile { get; init; } + public string[] FallbackModels { get; init; } = []; +} + +public sealed class ModelSelectionResult +{ + public string? RequestedProfileId { get; init; } + public string? SelectedProfileId { get; init; } + public required string ProviderId { get; init; } + public required string ModelId { get; init; } + public required ModelSelectionRequirements Requirements { get; init; } + public IReadOnlyList Candidates { get; init; } = []; + public string[] PreferredTags { get; init; } = []; + public string? Explanation { get; init; } +} + +public interface IModelSelectionPolicy +{ + ModelSelectionResult Resolve(ModelSelectionRequest request); +} diff --git a/src/OpenClaw.Core/Models/GatewayConfig.cs b/src/OpenClaw.Core/Models/GatewayConfig.cs index de23cc0..490828f 100644 --- a/src/OpenClaw.Core/Models/GatewayConfig.cs +++ b/src/OpenClaw.Core/Models/GatewayConfig.cs @@ -13,6 +13,7 @@ public sealed class GatewayConfig public string? AuthToken { get; set; } public RuntimeConfig Runtime { get; set; } = new(); public LlmProviderConfig Llm { get; set; } = new(); + public ModelsConfig Models { get; set; } = new(); public MemoryConfig Memory { get; set; } = new(); public SecurityConfig Security { get; set; } = new(); public WebSocketConfig WebSocket { get; set; } = new(); @@ -609,6 +610,10 @@ public sealed class AgentRouteConfig public string? SenderId { get; set; } public string? SystemPrompt { get; set; } public string? ModelOverride { get; set; } + public string? ModelProfileId { get; set; } + public string[] PreferredModelTags { get; set; } = []; + public string[] FallbackModelProfileIds { get; set; } = []; + public ModelSelectionRequirements ModelRequirements { get; set; } = new(); public string? PresetId { get; set; } public string[] AllowedTools { get; set; } = []; } diff --git a/src/OpenClaw.Core/Models/IntegrationApiModels.cs b/src/OpenClaw.Core/Models/IntegrationApiModels.cs index d4573e0..ed0c43d 100644 --- a/src/OpenClaw.Core/Models/IntegrationApiModels.cs +++ b/src/OpenClaw.Core/Models/IntegrationApiModels.cs @@ -98,6 +98,7 @@ public sealed class IntegrationApprovalHistoryResponse public sealed class IntegrationProvidersResponse { + public ModelProfilesStatusResponse? ModelProfiles { get; init; } public IReadOnlyList Routes { get; init; } = []; public IReadOnlyList Usage { get; init; } = []; public IReadOnlyList Policies { get; init; } = []; diff --git a/src/OpenClaw.Core/Models/ModelProfiles.cs b/src/OpenClaw.Core/Models/ModelProfiles.cs new file mode 100644 index 0000000..fe240b8 --- /dev/null +++ b/src/OpenClaw.Core/Models/ModelProfiles.cs @@ -0,0 +1,147 @@ +namespace OpenClaw.Core.Models; + +public sealed class ModelsConfig +{ + public string? DefaultProfile { get; set; } + public List Profiles { get; set; } = []; +} + +public sealed class ModelProfileConfig +{ + public string Id { get; set; } = ""; + public string Provider { get; set; } = ""; + public string Model { get; set; } = ""; + public string? BaseUrl { get; set; } + public string? ApiKey { get; set; } + public string[] Tags { get; set; } = []; + public string[] FallbackProfileIds { get; set; } = []; + public string[] FallbackModels { get; set; } = []; + public ModelCapabilities Capabilities { get; set; } = new(); +} + +public sealed class ModelCapabilities +{ + public bool SupportsTools { get; set; } + public bool SupportsVision { get; set; } + public bool SupportsJsonSchema { get; set; } + public bool SupportsStructuredOutputs { get; set; } + public bool SupportsStreaming { get; set; } = true; + public bool SupportsParallelToolCalls { get; set; } + public bool SupportsReasoningEffort { get; set; } + public bool SupportsSystemMessages { get; set; } = true; + public bool SupportsImageInput { get; set; } + public bool SupportsAudioInput { get; set; } + public int MaxContextTokens { get; set; } + public int MaxOutputTokens { get; set; } +} + +public sealed class ModelSelectionRequirements +{ + public bool? SupportsTools { get; set; } + public bool? SupportsVision { get; set; } + public bool? SupportsJsonSchema { get; set; } + public bool? SupportsStructuredOutputs { get; set; } + public bool? SupportsStreaming { get; set; } + public bool? SupportsParallelToolCalls { get; set; } + public bool? SupportsReasoningEffort { get; set; } + public bool? SupportsSystemMessages { get; set; } + public bool? SupportsImageInput { get; set; } + public bool? SupportsAudioInput { get; set; } + public int? MinContextTokens { get; set; } + public int? MinOutputTokens { get; set; } +} + +public sealed class ModelProfile +{ + public required string Id { get; init; } + public required string ProviderId { get; init; } + public required string ModelId { get; init; } + public string? BaseUrl { get; init; } + public string? ApiKey { get; init; } + public string[] Tags { get; init; } = []; + public string[] FallbackProfileIds { get; init; } = []; + public string[] FallbackModels { get; init; } = []; + public required ModelCapabilities Capabilities { get; init; } + public bool IsImplicit { get; init; } +} + +public sealed class ModelProfileStatus +{ + public required string Id { get; init; } + public required string ProviderId { get; init; } + public required string ModelId { get; init; } + public bool IsDefault { get; init; } + public bool IsImplicit { get; init; } + public bool IsAvailable { get; init; } + public string[] Tags { get; init; } = []; + public required ModelCapabilities Capabilities { get; init; } + public string[] ValidationIssues { get; init; } = []; + public string[] FallbackProfileIds { get; init; } = []; + public string[] FallbackModels { get; init; } = []; +} + +public sealed class ModelSelectionDescriptor +{ + public string? ProfileId { get; set; } + public string[] PreferredTags { get; set; } = []; + public string[] FallbackProfileIds { get; set; } = []; + public ModelSelectionRequirements Requirements { get; set; } = new(); +} + +public sealed class ModelProfilesStatusResponse +{ + public string? DefaultProfileId { get; init; } + public IReadOnlyList Profiles { get; init; } = []; +} + +public sealed class ModelSelectionDoctorResponse +{ + public string? DefaultProfileId { get; init; } + public IReadOnlyList Errors { get; init; } = []; + public IReadOnlyList Warnings { get; init; } = []; + public IReadOnlyList Profiles { get; init; } = []; +} + +public sealed class ModelEvaluationRequest +{ + public string? ProfileId { get; init; } + public string[] ProfileIds { get; init; } = []; + public string[] ScenarioIds { get; init; } = []; + public bool IncludeMarkdown { get; init; } = true; +} + +public sealed class ModelEvaluationScenarioResult +{ + public required string ScenarioId { get; init; } + public required string Name { get; init; } + public string Status { get; init; } = "unknown"; + public string? Summary { get; init; } + public long LatencyMs { get; init; } + public long InputTokens { get; init; } + public long OutputTokens { get; init; } + public bool MalformedJson { get; init; } + public int ToolCalls { get; init; } + public string? Error { get; init; } +} + +public sealed class ModelEvaluationProfileReport +{ + public required string ProfileId { get; init; } + public required string ProviderId { get; init; } + public required string ModelId { get; init; } + public DateTimeOffset StartedAtUtc { get; init; } + public DateTimeOffset CompletedAtUtc { get; init; } + public IReadOnlyList Scenarios { get; init; } = []; +} + +public sealed class ModelEvaluationReport +{ + public required string RunId { get; init; } + public DateTimeOffset StartedAtUtc { get; init; } + public DateTimeOffset CompletedAtUtc { get; init; } + public IReadOnlyList ScenarioIds { get; init; } = []; + public IReadOnlyList Profiles { get; init; } = []; + public string? JsonPath { get; init; } + public string? MarkdownPath { get; init; } + public string? Markdown { get; init; } +} diff --git a/src/OpenClaw.Core/Models/ModelSelectionException.cs b/src/OpenClaw.Core/Models/ModelSelectionException.cs new file mode 100644 index 0000000..0725491 --- /dev/null +++ b/src/OpenClaw.Core/Models/ModelSelectionException.cs @@ -0,0 +1,9 @@ +namespace OpenClaw.Core.Models; + +public sealed class ModelSelectionException : InvalidOperationException +{ + public ModelSelectionException(string message) + : base(message) + { + } +} diff --git a/src/OpenClaw.Core/Models/OperatorApiModels.cs b/src/OpenClaw.Core/Models/OperatorApiModels.cs index b3db3a6..592747c 100644 --- a/src/OpenClaw.Core/Models/OperatorApiModels.cs +++ b/src/OpenClaw.Core/Models/OperatorApiModels.cs @@ -45,11 +45,14 @@ public sealed class ProviderPolicyListResponse public sealed class ProviderRouteHealthSnapshot { + public string? ProfileId { get; init; } public required string ProviderId { get; init; } public required string ModelId { get; init; } public bool IsDefaultRoute { get; init; } public bool IsDynamic { get; init; } public string? OwnerId { get; init; } + public string[] Tags { get; init; } = []; + public string[] ValidationIssues { get; init; } = []; public string CircuitState { get; init; } = "Closed"; public long Requests { get; init; } public long Retries { get; init; } @@ -73,6 +76,7 @@ public sealed class ProviderTurnUsageEntry public sealed class ProviderAdminResponse { public IReadOnlyList Routes { get; init; } = []; + public ModelProfilesStatusResponse? ModelProfiles { get; init; } public IReadOnlyList Usage { get; init; } = []; public IReadOnlyList Policies { get; init; } = []; public IReadOnlyList RecentTurns { get; init; } = []; diff --git a/src/OpenClaw.Core/Models/Session.cs b/src/OpenClaw.Core/Models/Session.cs index fef67c6..1423798 100644 --- a/src/OpenClaw.Core/Models/Session.cs +++ b/src/OpenClaw.Core/Models/Session.cs @@ -28,6 +28,18 @@ public sealed class Session /// Optional model override for this specific session (set via /model command). public string? ModelOverride { get; set; } + /// Optional named model profile selected for this session or route. + public string? ModelProfileId { get; set; } + + /// Optional route/session profile preferences used by profile-aware model selection. + public string[] PreferredModelTags { get; set; } = []; + + /// Optional route/session fallback profile order used when the selected profile lacks required capabilities. + public string[] FallbackModelProfileIds { get; set; } = []; + + /// Optional route/session capability requirements used during profile-aware model selection. + public ModelSelectionRequirements ModelRequirements { get; set; } = new(); + /// Optional route-scoped system prompt appended by gateway routing before runtime execution. public string? SystemPromptOverride { get; set; } @@ -123,6 +135,24 @@ public sealed record ToolInvocation [JsonSerializable(typeof(RuntimeConfig))] [JsonSerializable(typeof(GatewayRuntimeState))] [JsonSerializable(typeof(LlmProviderConfig))] +[JsonSerializable(typeof(ModelsConfig))] +[JsonSerializable(typeof(ModelProfileConfig))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(ModelCapabilities))] +[JsonSerializable(typeof(ModelSelectionRequirements))] +[JsonSerializable(typeof(ModelProfile))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(ModelProfileStatus))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(ModelProfilesStatusResponse))] +[JsonSerializable(typeof(ModelSelectionDoctorResponse))] +[JsonSerializable(typeof(ModelSelectionDescriptor))] +[JsonSerializable(typeof(ModelEvaluationRequest))] +[JsonSerializable(typeof(ModelEvaluationScenarioResult))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(ModelEvaluationProfileReport))] +[JsonSerializable(typeof(List))] +[JsonSerializable(typeof(ModelEvaluationReport))] [JsonSerializable(typeof(TokenCostRateConfig))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(MemoryConfig))] diff --git a/src/OpenClaw.Core/Validation/ConfigValidator.cs b/src/OpenClaw.Core/Validation/ConfigValidator.cs index a3ad8ff..413a4dd 100644 --- a/src/OpenClaw.Core/Validation/ConfigValidator.cs +++ b/src/OpenClaw.Core/Validation/ConfigValidator.cs @@ -52,6 +52,7 @@ public static IReadOnlyList Validate(Models.GatewayConfig config) errors.Add($"Llm.CircuitBreakerThreshold must be >= 1 (got {config.Llm.CircuitBreakerThreshold})."); if (config.Llm.CircuitBreakerCooldownSeconds < 1) errors.Add($"Llm.CircuitBreakerCooldownSeconds must be >= 1 (got {config.Llm.CircuitBreakerCooldownSeconds})."); + ValidateModelProfiles(config, errors, pluginBackedProvidersPossible); // Memory if (string.IsNullOrWhiteSpace(config.Memory.StoragePath)) @@ -510,6 +511,55 @@ private static void ValidateRootSet(string field, string[] roots, ICollection errors, bool pluginBackedProvidersPossible) + { + if (config.Models.Profiles.Count == 0) + return; + + var profileIds = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var profile in config.Models.Profiles) + { + if (string.IsNullOrWhiteSpace(profile.Id)) + { + errors.Add("Models.Profiles[].Id must be set."); + continue; + } + + if (!profileIds.Add(profile.Id)) + errors.Add($"Models.Profiles contains duplicate id '{profile.Id}'."); + + if (string.IsNullOrWhiteSpace(profile.Provider)) + errors.Add($"Models.Profiles.{profile.Id}.Provider must be set."); + else if (!pluginBackedProvidersPossible && !BuiltInLlmProviders.Contains(profile.Provider)) + errors.Add($"Models.Profiles.{profile.Id}.Provider '{profile.Provider}' is not a supported built-in provider."); + + if (string.IsNullOrWhiteSpace(profile.Model)) + errors.Add($"Models.Profiles.{profile.Id}.Model must be set."); + if (profile.Capabilities.MaxContextTokens < 0) + errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxContextTokens must be >= 0."); + if (profile.Capabilities.MaxOutputTokens < 0) + errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxOutputTokens must be >= 0."); + } + + if (!string.IsNullOrWhiteSpace(config.Models.DefaultProfile) && + !profileIds.Contains(config.Models.DefaultProfile)) + { + errors.Add($"Models.DefaultProfile '{config.Models.DefaultProfile}' does not exist in Models.Profiles."); + } + + foreach (var (routeId, route) in config.Routing.Routes) + { + if (!string.IsNullOrWhiteSpace(route.ModelProfileId) && !profileIds.Contains(route.ModelProfileId)) + errors.Add($"Routing.Routes.{routeId}.ModelProfileId '{route.ModelProfileId}' does not exist in Models.Profiles."); + + foreach (var fallbackId in route.FallbackModelProfileIds.Where(static item => !string.IsNullOrWhiteSpace(item))) + { + if (!profileIds.Contains(fallbackId)) + errors.Add($"Routing.Routes.{routeId}.FallbackModelProfileIds contains unknown profile '{fallbackId}'."); + } + } + } + private static string ResolveConfiguredPath(string? path) => ConfigPathResolver.Resolve(path); } diff --git a/src/OpenClaw.Core/Validation/DoctorCheck.cs b/src/OpenClaw.Core/Validation/DoctorCheck.cs index 81ac818..d577b64 100644 --- a/src/OpenClaw.Core/Validation/DoctorCheck.cs +++ b/src/OpenClaw.Core/Validation/DoctorCheck.cs @@ -24,6 +24,22 @@ public static async Task RunAsync(GatewayConfig config, GatewayRuntimeStat allPassed &= Check("LLM API Key configured", () => !string.IsNullOrWhiteSpace(config.Llm.ApiKey)); allPassed &= Check("LLM max tokens > 0", () => config.Llm.MaxTokens > 0); + allPassed &= Check( + "Model profile configuration is internally consistent", + () => + { + var profileIds = config.Models.Profiles + .Where(static profile => !string.IsNullOrWhiteSpace(profile.Id)) + .Select(static profile => profile.Id) + .Distinct(StringComparer.OrdinalIgnoreCase) + .Count(); + return config.Models.Profiles.Count == 0 || + (profileIds == config.Models.Profiles.Count && + (string.IsNullOrWhiteSpace(config.Models.DefaultProfile) || + config.Models.Profiles.Any(profile => string.Equals(profile.Id, config.Models.DefaultProfile, StringComparison.OrdinalIgnoreCase)))); + }, + warnOnly: false, + detail: "Check Models.DefaultProfile, duplicate profile ids, and route profile references."); var workspaceRoot = ResolveConfiguredPath(config.Tooling.WorkspaceRoot); if (config.Tooling.WorkspaceOnly) diff --git a/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs b/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs index c2c7533..f68dfc2 100644 --- a/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs +++ b/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs @@ -12,6 +12,7 @@ using OpenClaw.Core.Sessions; using OpenClaw.Gateway.Bootstrap; using OpenClaw.Gateway.Extensions; +using OpenClaw.Gateway.Models; namespace OpenClaw.Gateway.Composition; @@ -36,6 +37,10 @@ public static IServiceCollection AddOpenClawCoreServices(this IServiceCollection services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(sp => sp.GetRequiredService()); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(sp => new ProviderPolicyService( config.Memory.StoragePath, diff --git a/src/OpenClaw.Gateway/Composition/IntegrationApiFacade.cs b/src/OpenClaw.Gateway/Composition/IntegrationApiFacade.cs index f759fad..ac2aeba 100644 --- a/src/OpenClaw.Gateway/Composition/IntegrationApiFacade.cs +++ b/src/OpenClaw.Gateway/Composition/IntegrationApiFacade.cs @@ -180,6 +180,11 @@ public IntegrationApprovalHistoryResponse GetApprovalHistory(ApprovalHistoryQuer public IntegrationProvidersResponse GetProviders(int recentTurnsLimit) => new() { + ModelProfiles = new ModelProfilesStatusResponse + { + DefaultProfileId = _runtime.Operations.ModelProfiles.DefaultProfileId, + Profiles = _runtime.Operations.ModelProfiles.ListStatuses() + }, Routes = _runtime.Operations.LlmExecution.SnapshotRoutes(), Usage = _runtime.ProviderUsage.Snapshot(), Policies = _runtime.Operations.ProviderPolicies.List(), diff --git a/src/OpenClaw.Gateway/Composition/RuntimeInitializationExtensions.cs b/src/OpenClaw.Gateway/Composition/RuntimeInitializationExtensions.cs index ba909c5..3a9937d 100644 --- a/src/OpenClaw.Gateway/Composition/RuntimeInitializationExtensions.cs +++ b/src/OpenClaw.Gateway/Composition/RuntimeInitializationExtensions.cs @@ -21,6 +21,7 @@ using OpenClaw.Gateway; using OpenClaw.Gateway.Bootstrap; using OpenClaw.Gateway.Extensions; +using OpenClaw.Gateway.Models; using OpenClaw.Gateway.Profiles; using OpenClaw.Gateway.Tools; @@ -217,6 +218,7 @@ private static RuntimeServices ResolveRuntimeServices(WebApplication app) ApprovalAuditStore = app.Services.GetRequiredService(), RuntimeMetrics = app.Services.GetRequiredService(), ProviderUsage = app.Services.GetRequiredService(), + ModelProfiles = app.Services.GetRequiredService(), ProviderRegistry = app.Services.GetRequiredService(), ProviderPolicies = app.Services.GetRequiredService(), LlmExecutionService = app.Services.GetRequiredService(), @@ -426,6 +428,7 @@ private static GatewayAppRuntime CreateGatewayRuntime( { var operations = new RuntimeOperationsState { + ModelProfiles = services.ModelProfiles, ProviderPolicies = services.ProviderPolicies, ProviderRegistry = services.ProviderRegistry, LlmExecution = services.LlmExecutionService, @@ -1033,6 +1036,7 @@ private sealed class RuntimeServices public required ApprovalAuditStore ApprovalAuditStore { get; init; } public required RuntimeMetrics RuntimeMetrics { get; init; } public required ProviderUsageTracker ProviderUsage { get; init; } + public required ConfiguredModelProfileRegistry ModelProfiles { get; init; } public required LlmProviderRegistry ProviderRegistry { get; init; } public required ProviderPolicyService ProviderPolicies { get; init; } public required GatewayLlmExecutionService LlmExecutionService { get; init; } diff --git a/src/OpenClaw.Gateway/Endpoints/AdminEndpoints.cs b/src/OpenClaw.Gateway/Endpoints/AdminEndpoints.cs index 9d93ba4..1e0556d 100644 --- a/src/OpenClaw.Gateway/Endpoints/AdminEndpoints.cs +++ b/src/OpenClaw.Gateway/Endpoints/AdminEndpoints.cs @@ -14,6 +14,7 @@ using OpenClaw.Gateway; using OpenClaw.Gateway.Bootstrap; using OpenClaw.Gateway.Composition; +using OpenClaw.Gateway.Models; using QRCoder; namespace OpenClaw.Gateway.Endpoints; @@ -37,6 +38,12 @@ public static void MapOpenClawAdminEndpoints( var facade = IntegrationApiFacade.Create(startup, runtime, app.Services); var sessionAdminStore = (ISessionAdminStore)app.Services.GetRequiredService(); var operations = runtime.Operations; + var modelEvaluationRunner = app.Services.GetService() + ?? new ModelEvaluationRunner( + operations.ModelProfiles as ConfiguredModelProfileRegistry + ?? new ConfiguredModelProfileRegistry(startup.Config, NullLogger.Instance), + startup.Config, + NullLogger.Instance); app.MapGet("/auth/session", (HttpContext ctx) => { @@ -786,6 +793,11 @@ public static void MapOpenClawAdminEndpoints( return Results.Json(new ProviderAdminResponse { + ModelProfiles = new ModelProfilesStatusResponse + { + DefaultProfileId = operations.ModelProfiles.DefaultProfileId, + Profiles = operations.ModelProfiles.ListStatuses() + }, Routes = operations.LlmExecution.SnapshotRoutes(), Usage = runtime.ProviderUsage.Snapshot(), Policies = operations.ProviderPolicies.List(), @@ -793,6 +805,45 @@ public static void MapOpenClawAdminEndpoints( }, CoreJsonContext.Default.ProviderAdminResponse); }); + app.MapGet("/admin/models", (HttpContext ctx) => + { + var authResult = AuthorizeOperator(ctx, startup, browserSessions, operations, requireCsrf: false, endpointScope: "admin.models"); + if (authResult.Failure is not null) + return authResult.Failure; + + return Results.Json( + new ModelProfilesStatusResponse + { + DefaultProfileId = operations.ModelProfiles.DefaultProfileId, + Profiles = operations.ModelProfiles.ListStatuses() + }, + CoreJsonContext.Default.ModelProfilesStatusResponse); + }); + + app.MapGet("/admin/models/doctor", (HttpContext ctx) => + { + var authResult = AuthorizeOperator(ctx, startup, browserSessions, operations, requireCsrf: false, endpointScope: "admin.models.doctor"); + if (authResult.Failure is not null) + return authResult.Failure; + + return Results.Json(modelEvaluationRunner.BuildDoctor(), CoreJsonContext.Default.ModelSelectionDoctorResponse); + }); + + app.MapPost("/admin/models/evaluations", async (HttpContext ctx) => + { + var authResult = AuthorizeOperator(ctx, startup, browserSessions, operations, requireCsrf: true, endpointScope: "admin.models.evaluate"); + if (authResult.Failure is not null) + return authResult.Failure; + + var requestPayload = await ReadJsonBodyAsync(ctx, CoreJsonContext.Default.ModelEvaluationRequest); + if (requestPayload.Failure is not null) + return requestPayload.Failure; + + var request = requestPayload.Value ?? new ModelEvaluationRequest(); + var report = await modelEvaluationRunner.RunAsync(request, ctx.RequestAborted); + return Results.Json(report, CoreJsonContext.Default.ModelEvaluationReport); + }); + app.MapGet("/admin/providers/policies", (HttpContext ctx) => { var authResult = AuthorizeOperator(ctx, startup, browserSessions, operations, requireCsrf: false, endpointScope: "admin.provider-policies"); diff --git a/src/OpenClaw.Gateway/Endpoints/OpenAiEndpoints.cs b/src/OpenClaw.Gateway/Endpoints/OpenAiEndpoints.cs index c6255d1..ef897b0 100644 --- a/src/OpenClaw.Gateway/Endpoints/OpenAiEndpoints.cs +++ b/src/OpenClaw.Gateway/Endpoints/OpenAiEndpoints.cs @@ -92,7 +92,18 @@ public static void MapOpenClawOpenAiEndpoints( return; } if (req.Model is not null) - session.ModelOverride = req.Model; + { + if (runtime.Operations.ModelProfiles.TryGet(req.Model, out _)) + { + session.ModelProfileId = req.Model; + session.ModelOverride = null; + } + else + { + session.ModelOverride = req.Model; + session.ModelProfileId = null; + } + } var presetHeader = ctx.Request.Headers.TryGetValue("X-OpenClaw-Preset", out var presetValues) ? presetValues.ToString() : null; @@ -373,7 +384,18 @@ await ctx.Response.WriteAsync( return; } if (req.Model is not null) - session.ModelOverride = req.Model; + { + if (runtime.Operations.ModelProfiles.TryGet(req.Model, out _)) + { + session.ModelProfileId = req.Model; + session.ModelOverride = null; + } + else + { + session.ModelOverride = req.Model; + session.ModelProfileId = null; + } + } var responsesPresetHeader = ctx.Request.Headers.TryGetValue("X-OpenClaw-Preset", out var responsesPresetValues) ? responsesPresetValues.ToString() : null; diff --git a/src/OpenClaw.Gateway/Extensions/GatewayWorkers.cs b/src/OpenClaw.Gateway/Extensions/GatewayWorkers.cs index 3d275f7..e095ff6 100644 --- a/src/OpenClaw.Gateway/Extensions/GatewayWorkers.cs +++ b/src/OpenClaw.Gateway/Extensions/GatewayWorkers.cs @@ -409,6 +409,20 @@ await pipeline.OutboundWriter.WriteAsync(new OutboundMessage session.ModelOverride = string.IsNullOrWhiteSpace(resolvedRoute.ModelOverride) ? session.ModelOverride : resolvedRoute.ModelOverride.Trim(); + session.ModelProfileId = string.IsNullOrWhiteSpace(resolvedRoute.ModelProfileId) + ? null + : resolvedRoute.ModelProfileId.Trim(); + session.PreferredModelTags = resolvedRoute.PreferredModelTags + .Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + session.FallbackModelProfileIds = resolvedRoute.FallbackModelProfileIds + .Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + session.ModelRequirements = resolvedRoute.ModelRequirements ?? new ModelSelectionRequirements(); session.SystemPromptOverride = string.IsNullOrWhiteSpace(resolvedRoute.SystemPrompt) ? null : resolvedRoute.SystemPrompt.Trim(); @@ -423,6 +437,10 @@ await pipeline.OutboundWriter.WriteAsync(new OutboundMessage } else { + session.ModelProfileId = null; + session.PreferredModelTags = []; + session.FallbackModelProfileIds = []; + session.ModelRequirements = new ModelSelectionRequirements(); session.SystemPromptOverride = null; session.RoutePresetId = null; session.RouteAllowedTools = []; diff --git a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs index 3d903ee..92470c9 100644 --- a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs +++ b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs @@ -1,9 +1,12 @@ using System.Collections.Concurrent; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using OpenClaw.Agent; +using OpenClaw.Core.Abstractions; using OpenClaw.Core.Models; using OpenClaw.Core.Observability; +using OpenClaw.Gateway.Models; namespace OpenClaw.Gateway; @@ -20,7 +23,8 @@ private sealed class RouteState } private readonly GatewayConfig _config; - private readonly LlmProviderRegistry _registry; + private readonly ConfiguredModelProfileRegistry _modelProfiles; + private readonly IModelSelectionPolicy _selectionPolicy; private readonly ProviderPolicyService _policyService; private readonly RuntimeEventStore _eventStore; private readonly RuntimeMetrics _runtimeMetrics; @@ -30,7 +34,8 @@ private sealed class RouteState public GatewayLlmExecutionService( GatewayConfig config, - LlmProviderRegistry registry, + ConfiguredModelProfileRegistry modelProfiles, + IModelSelectionPolicy selectionPolicy, ProviderPolicyService policyService, RuntimeEventStore eventStore, RuntimeMetrics runtimeMetrics, @@ -38,7 +43,8 @@ public GatewayLlmExecutionService( ILogger logger) { _config = config; - _registry = registry; + _modelProfiles = modelProfiles; + _selectionPolicy = selectionPolicy; _policyService = policyService; _eventStore = eventStore; _runtimeMetrics = runtimeMetrics; @@ -46,40 +52,59 @@ public GatewayLlmExecutionService( _logger = logger; } + public GatewayLlmExecutionService( + GatewayConfig config, + LlmProviderRegistry registry, + ProviderPolicyService policyService, + RuntimeEventStore eventStore, + RuntimeMetrics runtimeMetrics, + ProviderUsageTracker providerUsage, + ILogger logger) + : this( + config, + new ConfiguredModelProfileRegistry(config, NullLogger.Instance), + new DefaultModelSelectionPolicy(new ConfiguredModelProfileRegistry(config, NullLogger.Instance)), + policyService, + eventStore, + runtimeMetrics, + providerUsage, + logger) + { + } + public CircuitState DefaultCircuitState - => GetRouteState(_config.Llm.Provider, _config.Llm.Model).CircuitBreaker.State; + => GetRouteState( + _modelProfiles.DefaultProfileId ?? "default", + _config.Llm.Provider, + _config.Llm.Model).CircuitBreaker.State; public IReadOnlyList SnapshotRoutes() - => _registry.Snapshot() - .SelectMany(registration => + => _modelProfiles.ListStatuses() + .Select(profile => { - var models = registration.Models.Length > 0 ? registration.Models : [_config.Llm.Model]; - return models.Distinct(StringComparer.OrdinalIgnoreCase).Select(modelId => + var state = GetRouteState(profile.Id, profile.ProviderId, profile.ModelId); + return new ProviderRouteHealthSnapshot { - var state = GetRouteState(registration.ProviderId, modelId); - return new ProviderRouteHealthSnapshot - { - ProviderId = registration.ProviderId, - ModelId = modelId, - IsDefaultRoute = registration.IsDefault && string.Equals(modelId, _config.Llm.Model, StringComparison.OrdinalIgnoreCase), - IsDynamic = registration.IsDynamic, - OwnerId = registration.OwnerId, - CircuitState = state.CircuitBreaker.State.ToString(), - Requests = Interlocked.Read(ref state.Requests), - Retries = Interlocked.Read(ref state.Retries), - Errors = Interlocked.Read(ref state.Errors), - LastError = state.LastError, - LastErrorAtUtc = state.LastErrorAtUtc - }; - }); + ProfileId = profile.Id, + ProviderId = profile.ProviderId, + ModelId = profile.ModelId, + IsDefaultRoute = profile.IsDefault, + CircuitState = state.CircuitBreaker.State.ToString(), + Requests = Interlocked.Read(ref state.Requests), + Retries = Interlocked.Read(ref state.Retries), + Errors = Interlocked.Read(ref state.Errors), + LastError = state.LastError, + LastErrorAtUtc = state.LastErrorAtUtc, + Tags = profile.Tags, + ValidationIssues = profile.ValidationIssues + }; }) - .OrderBy(static item => item.ProviderId, StringComparer.OrdinalIgnoreCase) - .ThenBy(static item => item.ModelId, StringComparer.OrdinalIgnoreCase) + .OrderBy(static item => item.ProfileId, StringComparer.OrdinalIgnoreCase) .ToArray(); public void ResetProvider(string providerId) { - foreach (var key in _routes.Keys.Where(key => key.StartsWith(providerId + ":", StringComparison.OrdinalIgnoreCase)).ToArray()) + foreach (var key in _routes.Keys.Where(key => key.Contains($":{providerId}:", StringComparison.OrdinalIgnoreCase)).ToArray()) { if (_routes.TryRemove(key, out var state)) state.CircuitBreaker.Reset(); @@ -94,98 +119,114 @@ public async Task GetResponseAsync( LlmExecutionEstimate estimate, CancellationToken ct) { - var resolved = _policyService.Resolve(session, _config.Llm); - var effectiveOptions = CreateEffectiveOptions(options, resolved, estimate); - var modelsToTry = new[] { resolved.ModelId } - .Concat(resolved.FallbackModels.Where(static item => !string.IsNullOrWhiteSpace(item))) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToArray(); + var selection = ResolveSelection(session, messages, options, streaming: false); + var legacyPolicy = _policyService.Resolve(session, _config.Llm); - RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {resolved.ProviderId}/{resolved.ModelId}", new() + RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {selection.ProviderId}/{selection.ModelId}", new() { - ["providerId"] = resolved.ProviderId, - ["modelId"] = resolved.ModelId, - ["policyRuleId"] = resolved.RuleId ?? "" + ["providerId"] = selection.ProviderId, + ["modelId"] = selection.ModelId, + ["profileId"] = selection.SelectedProfileId ?? "", + ["policyRuleId"] = legacyPolicy.RuleId ?? "" }); + if (!string.IsNullOrWhiteSpace(selection.Explanation)) + _logger.LogInformation("{Explanation}", selection.Explanation); Exception? lastError = null; - for (var modelIndex = 0; modelIndex < modelsToTry.Length; modelIndex++) + foreach (var candidate in selection.Candidates) { - var modelId = modelsToTry[modelIndex]; - var routeState = GetRouteState(resolved.ProviderId, modelId); - var chatClient = GetClient(resolved.ProviderId); + if (!_modelProfiles.TryGetRegistration(candidate.Profile.Id, out var registration) || registration?.Client is null) + continue; + + var modelsToTry = new[] { ResolveRequestedModelId(session, candidate.Profile) } + .Concat(candidate.FallbackModels.Where(static item => !string.IsNullOrWhiteSpace(item))) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); - for (var attempt = 0; attempt <= _config.Llm.RetryCount; attempt++) + for (var modelIndex = 0; modelIndex < modelsToTry.Length; modelIndex++) { - Interlocked.Increment(ref routeState.Requests); - _providerUsage.RecordRequest(resolved.ProviderId, modelId); + var modelId = modelsToTry[modelIndex]; + var routeState = GetRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, modelId); + var chatClient = registration.Client; + var effectiveOptions = CreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate); - if (attempt > 0 || modelIndex > 0) + for (var attempt = 0; attempt <= registration.ProviderConfig.RetryCount; attempt++) { - Interlocked.Increment(ref routeState.Retries); - turnContext.RecordRetry(); - _runtimeMetrics.IncrementLlmRetries(); - _providerUsage.RecordRetry(resolved.ProviderId, modelId); - var delayMs = Math.Min(4_000, (int)Math.Pow(2, attempt + modelIndex) * 500); - await Task.Delay(delayMs, ct); - } + Interlocked.Increment(ref routeState.Requests); + _providerUsage.RecordRequest(candidate.Profile.ProviderId, modelId); - try - { - RecordEvent(session, turnContext, "llm", "request_started", "info", $"LLM request started for {resolved.ProviderId}/{modelId}", new() + if (attempt > 0 || modelIndex > 0) { - ["providerId"] = resolved.ProviderId, - ["modelId"] = modelId - }); - - effectiveOptions.ModelId = modelId; - var response = await routeState.CircuitBreaker.ExecuteAsync(async innerCt => + Interlocked.Increment(ref routeState.Retries); + turnContext.RecordRetry(); + _runtimeMetrics.IncrementLlmRetries(); + _providerUsage.RecordRetry(candidate.Profile.ProviderId, modelId); + var delayMs = Math.Min(4_000, (int)Math.Pow(2, attempt + modelIndex) * 500); + await Task.Delay(delayMs, ct); + } + + try { - if (_config.Llm.TimeoutSeconds > 0) + RecordEvent(session, turnContext, "llm", "request_started", "info", $"LLM request started for {candidate.Profile.ProviderId}/{modelId}", new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = modelId, + ["profileId"] = candidate.Profile.Id + }); + + effectiveOptions.ModelId = modelId; + var response = await routeState.CircuitBreaker.ExecuteAsync(async innerCt => { - using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(innerCt); - timeoutCts.CancelAfter(TimeSpan.FromSeconds(_config.Llm.TimeoutSeconds)); - return await chatClient.GetResponseAsync(messages, effectiveOptions, timeoutCts.Token); - } + if (registration.ProviderConfig.TimeoutSeconds > 0) + { + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(innerCt); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(registration.ProviderConfig.TimeoutSeconds)); + return await chatClient.GetResponseAsync(messages, effectiveOptions, timeoutCts.Token); + } - return await chatClient.GetResponseAsync(messages, effectiveOptions, innerCt); - }, ct); + return await chatClient.GetResponseAsync(messages, effectiveOptions, innerCt); + }, ct); - RecordEvent(session, turnContext, "llm", "request_completed", "info", $"LLM request completed for {resolved.ProviderId}/{modelId}", new() - { - ["providerId"] = resolved.ProviderId, - ["modelId"] = modelId - }); + RecordEvent(session, turnContext, "llm", "request_completed", "info", $"LLM request completed for {candidate.Profile.ProviderId}/{modelId}", new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = modelId, + ["profileId"] = candidate.Profile.Id + }); - return new LlmExecutionResult + return new LlmExecutionResult + { + ProfileId = candidate.Profile.Id, + ProviderId = candidate.Profile.ProviderId, + ModelId = modelId, + PolicyRuleId = legacyPolicy.RuleId, + SelectionExplanation = selection.Explanation, + Response = response + }; + } + catch (OperationCanceledException) when (ct.IsCancellationRequested) { - ProviderId = resolved.ProviderId, - ModelId = modelId, - PolicyRuleId = resolved.RuleId, - Response = response - }; - } - catch (OperationCanceledException) when (ct.IsCancellationRequested) - { - throw; - } - catch (Exception ex) - { - lastError = ex; - Interlocked.Increment(ref routeState.Errors); - routeState.LastError = ex.Message; - routeState.LastErrorAtUtc = DateTimeOffset.UtcNow; - _runtimeMetrics.IncrementLlmErrors(); - _providerUsage.RecordError(resolved.ProviderId, modelId); - RecordEvent(session, turnContext, "llm", "request_failed", "error", ex.Message, new() + throw; + } + catch (Exception ex) { - ["providerId"] = resolved.ProviderId, - ["modelId"] = modelId, - ["exceptionType"] = ex.GetType().Name - }); - - if (!IsTransient(ex)) - break; + lastError = ex; + Interlocked.Increment(ref routeState.Errors); + routeState.LastError = ex.Message; + routeState.LastErrorAtUtc = DateTimeOffset.UtcNow; + _runtimeMetrics.IncrementLlmErrors(); + _providerUsage.RecordError(candidate.Profile.ProviderId, modelId); + RecordEvent(session, turnContext, "llm", "request_failed", "error", ex.Message, new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = modelId, + ["profileId"] = candidate.Profile.Id, + ["exceptionType"] = ex.GetType().Name + }); + + if (!IsTransient(ex)) + break; + } } } } @@ -201,43 +242,56 @@ public Task StartStreamingAsync( LlmExecutionEstimate estimate, CancellationToken ct) { - var resolved = _policyService.Resolve(session, _config.Llm); - var effectiveOptions = CreateEffectiveOptions(options, resolved, estimate); - var routeState = GetRouteState(resolved.ProviderId, resolved.ModelId); - var chatClient = GetClient(resolved.ProviderId); + var selection = ResolveSelection(session, messages, options, streaming: true); + var legacyPolicy = _policyService.Resolve(session, _config.Llm); + var candidate = selection.Candidates.FirstOrDefault() + ?? throw new InvalidOperationException("No model profile candidate is available for streaming."); + if (!_modelProfiles.TryGetRegistration(candidate.Profile.Id, out var registration) || registration?.Client is null) + throw new ModelSelectionException($"Selected model profile '{candidate.Profile.Id}' is not available."); + + var effectiveOptions = CreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate); + var selectedModelId = ResolveRequestedModelId(session, candidate.Profile); + var routeState = GetRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, selectedModelId); + var chatClient = registration.Client; Interlocked.Increment(ref routeState.Requests); - _providerUsage.RecordRequest(resolved.ProviderId, resolved.ModelId); - RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {resolved.ProviderId}/{resolved.ModelId}", new() + _providerUsage.RecordRequest(candidate.Profile.ProviderId, selectedModelId); + RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {candidate.Profile.ProviderId}/{selectedModelId}", new() { - ["providerId"] = resolved.ProviderId, - ["modelId"] = resolved.ModelId, - ["policyRuleId"] = resolved.RuleId ?? "" + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = selectedModelId, + ["profileId"] = candidate.Profile.Id, + ["policyRuleId"] = legacyPolicy.RuleId ?? "" }); - RecordEvent(session, turnContext, "llm", "stream_started", "info", $"LLM stream started for {resolved.ProviderId}/{resolved.ModelId}", new() + RecordEvent(session, turnContext, "llm", "stream_started", "info", $"LLM stream started for {candidate.Profile.ProviderId}/{selectedModelId}", new() { - ["providerId"] = resolved.ProviderId, - ["modelId"] = resolved.ModelId, - ["policyRuleId"] = resolved.RuleId ?? "" + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = selectedModelId, + ["profileId"] = candidate.Profile.Id, + ["policyRuleId"] = legacyPolicy.RuleId ?? "" }); - effectiveOptions.ModelId = resolved.ModelId; + effectiveOptions.ModelId = selectedModelId; IAsyncEnumerable updates = StreamWithCircuitAsync( session, turnContext, chatClient, routeState, - resolved.ProviderId, - resolved.ModelId, + candidate.Profile.ProviderId, + selectedModelId, messages, effectiveOptions, + registration.ProviderConfig.TimeoutSeconds, + candidate.Profile.Id, ct); return Task.FromResult(new LlmStreamingExecutionResult { - ProviderId = resolved.ProviderId, - ModelId = resolved.ModelId, - PolicyRuleId = resolved.RuleId, + ProfileId = candidate.Profile.Id, + ProviderId = candidate.Profile.ProviderId, + ModelId = selectedModelId, + PolicyRuleId = legacyPolicy.RuleId, + SelectionExplanation = selection.Explanation, Updates = updates }); } @@ -251,15 +305,17 @@ private async IAsyncEnumerable StreamWithCircuitAsync( string modelId, IReadOnlyList messages, ChatOptions options, + int timeoutSeconds, + string profileId, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct) { routeState.CircuitBreaker.ThrowIfOpen(); CancellationToken activeToken = ct; CancellationTokenSource? timeoutCts = null; - if (_config.Llm.TimeoutSeconds > 0) + if (timeoutSeconds > 0) { timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(ct); - timeoutCts.CancelAfter(TimeSpan.FromSeconds(_config.Llm.TimeoutSeconds)); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(timeoutSeconds)); activeToken = timeoutCts.Token; } @@ -295,6 +351,7 @@ private async IAsyncEnumerable StreamWithCircuitAsync( { ["providerId"] = providerId, ["modelId"] = modelId, + ["profileId"] = profileId, ["exceptionType"] = ex.GetType().Name }); throw; @@ -307,7 +364,8 @@ private async IAsyncEnumerable StreamWithCircuitAsync( RecordEvent(session, turnContext, "llm", "stream_completed", "info", $"LLM stream completed for {providerId}/{modelId}", new() { ["providerId"] = providerId, - ["modelId"] = modelId + ["modelId"] = modelId, + ["profileId"] = profileId }); } finally @@ -316,21 +374,9 @@ private async IAsyncEnumerable StreamWithCircuitAsync( } } - private IChatClient GetClient(string providerId) - { - if (!_registry.TryGet(providerId, out var registration) || registration is null) - { - throw new InvalidOperationException( - $"Provider '{providerId}' is not registered. " + - $"Available providers: {string.Join(", ", _registry.Snapshot().Select(static item => item.ProviderId))}"); - } - - return registration.Client; - } - - private RouteState GetRouteState(string providerId, string modelId) + private RouteState GetRouteState(string profileId, string providerId, string modelId) => _routes.GetOrAdd( - $"{providerId}:{modelId}", + $"{profileId}:{providerId}:{modelId}", _ => new RouteState { CircuitBreaker = new CircuitBreaker( @@ -339,26 +385,60 @@ private RouteState GetRouteState(string providerId, string modelId) _logger) }); - private ChatOptions CreateEffectiveOptions(ChatOptions source, ResolvedProviderRoute resolved, LlmExecutionEstimate estimate) + private ModelSelectionResult ResolveSelection( + Session session, + IReadOnlyList messages, + ChatOptions options, + bool streaming) + { + var explicitProfileId = !string.IsNullOrWhiteSpace(session.ModelProfileId) + ? session.ModelProfileId + : (!string.IsNullOrWhiteSpace(session.ModelOverride) && _modelProfiles.TryGet(session.ModelOverride!, out _) + ? session.ModelOverride + : null); + return _selectionPolicy.Resolve(new ModelSelectionRequest + { + ExplicitProfileId = explicitProfileId, + Session = session, + Messages = messages, + Options = options, + Streaming = streaming + }); + } + + private ChatOptions CreateEffectiveOptions( + ChatOptions source, + ModelProfile profile, + LlmProviderConfig providerConfig, + ResolvedProviderRoute legacyPolicy, + LlmExecutionEstimate estimate) { var maxOutputTokens = source.MaxOutputTokens; - if (resolved.MaxOutputTokens > 0) - maxOutputTokens = maxOutputTokens is > 0 ? Math.Min(maxOutputTokens.Value, resolved.MaxOutputTokens) : resolved.MaxOutputTokens; + if (profile.Capabilities.MaxOutputTokens > 0) + maxOutputTokens = maxOutputTokens is > 0 ? Math.Min(maxOutputTokens.Value, profile.Capabilities.MaxOutputTokens) : profile.Capabilities.MaxOutputTokens; + if (legacyPolicy.MaxOutputTokens > 0) + maxOutputTokens = maxOutputTokens is > 0 ? Math.Min(maxOutputTokens.Value, legacyPolicy.MaxOutputTokens) : legacyPolicy.MaxOutputTokens; + + if (profile.Capabilities.MaxContextTokens > 0 && estimate.EstimatedInputTokens > profile.Capabilities.MaxContextTokens) + { + throw new ModelSelectionException( + $"Selected model profile '{profile.Id}' cannot satisfy this request because estimated input tokens ({estimate.EstimatedInputTokens}) exceed MaxContextTokens ({profile.Capabilities.MaxContextTokens})."); + } - if (resolved.MaxInputTokens > 0 && estimate.EstimatedInputTokens > resolved.MaxInputTokens) + if (legacyPolicy.MaxInputTokens > 0 && estimate.EstimatedInputTokens > legacyPolicy.MaxInputTokens) { throw new InvalidOperationException( - $"Provider policy blocked this request because estimated input tokens ({estimate.EstimatedInputTokens}) exceed maxInputTokens ({resolved.MaxInputTokens})."); + $"Provider policy blocked this request because estimated input tokens ({estimate.EstimatedInputTokens}) exceed maxInputTokens ({legacyPolicy.MaxInputTokens})."); } - if (resolved.MaxTotalTokens > 0) + if (legacyPolicy.MaxTotalTokens > 0) { - var configuredOutput = maxOutputTokens ?? _config.Llm.MaxTokens; - var remaining = resolved.MaxTotalTokens - estimate.EstimatedInputTokens; + var configuredOutput = maxOutputTokens ?? providerConfig.MaxTokens; + var remaining = legacyPolicy.MaxTotalTokens - estimate.EstimatedInputTokens; if (remaining <= 0) { throw new InvalidOperationException( - $"Provider policy blocked this request because estimated total tokens would exceed maxTotalTokens ({resolved.MaxTotalTokens})."); + $"Provider policy blocked this request because estimated total tokens would exceed maxTotalTokens ({legacyPolicy.MaxTotalTokens})."); } maxOutputTokens = Math.Min(configuredOutput, (int)remaining); @@ -366,7 +446,7 @@ private ChatOptions CreateEffectiveOptions(ChatOptions source, ResolvedProviderR return new ChatOptions { - ModelId = resolved.ModelId, + ModelId = profile.ModelId, MaxOutputTokens = maxOutputTokens, Temperature = source.Temperature, Tools = source.Tools, @@ -374,6 +454,14 @@ private ChatOptions CreateEffectiveOptions(ChatOptions source, ResolvedProviderR }; } + private string ResolveRequestedModelId(Session session, ModelProfile profile) + { + if (!string.IsNullOrWhiteSpace(session.ModelOverride) && !_modelProfiles.TryGet(session.ModelOverride!, out _)) + return session.ModelOverride!.Trim(); + + return profile.ModelId; + } + private void RecordEvent( Session session, TurnContext turnContext, diff --git a/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs new file mode 100644 index 0000000..10715c1 --- /dev/null +++ b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs @@ -0,0 +1,238 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using OpenClaw.Core.Abstractions; +using OpenClaw.Core.Models; +using OpenClaw.Gateway.Extensions; + +namespace OpenClaw.Gateway.Models; + +internal sealed class ConfiguredModelProfileRegistry : IModelProfileRegistry +{ + internal sealed class Registration + { + public required ModelProfile Profile { get; init; } + public required LlmProviderConfig ProviderConfig { get; init; } + public required string[] ValidationIssues { get; init; } + public IChatClient? Client { get; init; } + public bool IsDefault { get; init; } + } + + private readonly ConcurrentDictionary _registrations = new(StringComparer.OrdinalIgnoreCase); + private readonly ILogger _logger; + + public ConfiguredModelProfileRegistry(GatewayConfig config, ILogger logger) + { + _logger = logger; + DefaultProfileId = BuildRegistrations(config); + } + + public string? DefaultProfileId { get; } + + public bool TryGet(string profileId, out ModelProfile? profile) + { + if (_registrations.TryGetValue(profileId, out var registration)) + { + profile = registration.Profile; + return true; + } + + profile = null; + return false; + } + + internal bool TryGetRegistration(string profileId, out Registration? registration) + => _registrations.TryGetValue(profileId, out registration); + + public IReadOnlyList ListStatuses() + => _registrations.Values + .OrderByDescending(static item => item.IsDefault) + .ThenBy(static item => item.Profile.Id, StringComparer.OrdinalIgnoreCase) + .Select(static item => new ModelProfileStatus + { + Id = item.Profile.Id, + ProviderId = item.Profile.ProviderId, + ModelId = item.Profile.ModelId, + IsDefault = item.IsDefault, + IsImplicit = item.Profile.IsImplicit, + IsAvailable = item.Client is not null && item.ValidationIssues.Length == 0, + Tags = item.Profile.Tags, + Capabilities = item.Profile.Capabilities, + ValidationIssues = item.ValidationIssues, + FallbackProfileIds = item.Profile.FallbackProfileIds, + FallbackModels = item.Profile.FallbackModels + }) + .ToArray(); + + private string BuildRegistrations(GatewayConfig config) + { + var defaultProfileId = Normalize(config.Models.DefaultProfile); + var configs = config.Models.Profiles.Count > 0 + ? config.Models.Profiles + : [CreateImplicitConfig(config)]; + + var defaultId = defaultProfileId; + foreach (var profileConfig in configs) + { + var profile = ToProfile(config, profileConfig); + var issues = ValidateProfile(profile, config).ToArray(); + var providerConfig = BuildProviderConfig(config, profile); + IChatClient? client = null; + if (issues.Length == 0) + { + try + { + client = LlmClientFactory.CreateChatClient(providerConfig); + } + catch (Exception ex) + { + issues = [.. issues, ex.Message]; + _logger.LogWarning(ex, "Failed to initialize model profile {ProfileId}", profile.Id); + } + } + + var isDefault = + string.Equals(profile.Id, defaultProfileId, StringComparison.OrdinalIgnoreCase) || + (defaultProfileId is null && profile.IsImplicit); + _registrations[profile.Id] = new Registration + { + Profile = profile, + ProviderConfig = providerConfig, + ValidationIssues = issues, + Client = client, + IsDefault = isDefault + }; + + if (defaultId is null && profile.IsImplicit) + defaultId = profile.Id; + } + + if (defaultId is null || !_registrations.ContainsKey(defaultId)) + { + defaultId = _registrations.Keys.OrderBy(static item => item, StringComparer.OrdinalIgnoreCase).FirstOrDefault(); + if (defaultId is not null && _registrations.TryGetValue(defaultId, out var registration)) + { + _registrations[defaultId] = new Registration + { + Profile = registration.Profile, + ProviderConfig = registration.ProviderConfig, + ValidationIssues = registration.ValidationIssues, + Client = registration.Client, + IsDefault = true + }; + } + } + + return defaultId ?? "default"; + } + + private static ModelProfileConfig CreateImplicitConfig(GatewayConfig config) + => new() + { + Id = "default", + Provider = config.Llm.Provider, + Model = config.Llm.Model, + BaseUrl = config.Llm.Endpoint, + ApiKey = config.Llm.ApiKey, + FallbackModels = config.Llm.FallbackModels, + Capabilities = GuessCapabilities(config.Llm.Provider) + }; + + private static ModelCapabilities GuessCapabilities(string providerId) + { + var provider = (providerId ?? string.Empty).Trim().ToLowerInvariant(); + var supportsTools = provider is "openai" or "openai-compatible" or "azure-openai" or "groq" or "together" or "lmstudio" or "anthropic" or "claude" or "gemini" or "google"; + var supportsVision = provider is "openai" or "openai-compatible" or "azure-openai" or "gemini" or "google" or "ollama"; + return new ModelCapabilities + { + SupportsTools = supportsTools, + SupportsVision = supportsVision, + SupportsJsonSchema = provider is "openai" or "openai-compatible" or "azure-openai", + SupportsStructuredOutputs = provider is "openai" or "openai-compatible" or "azure-openai", + SupportsStreaming = true, + SupportsParallelToolCalls = provider is "openai" or "openai-compatible" or "azure-openai", + SupportsReasoningEffort = provider is "openai" or "openai-compatible" or "azure-openai", + SupportsSystemMessages = true, + SupportsImageInput = supportsVision, + SupportsAudioInput = provider is "openai" or "openai-compatible" or "azure-openai" + }; + } + + private static ModelProfile ToProfile(GatewayConfig config, ModelProfileConfig model) + => new() + { + Id = Normalize(model.Id) ?? "default", + ProviderId = Normalize(model.Provider) ?? config.Llm.Provider, + ModelId = Normalize(model.Model) ?? config.Llm.Model, + BaseUrl = Normalize(model.BaseUrl), + ApiKey = Normalize(model.ApiKey), + Tags = NormalizeDistinct(model.Tags), + FallbackProfileIds = NormalizeDistinct(model.FallbackProfileIds), + FallbackModels = NormalizeDistinct(model.FallbackModels), + Capabilities = model.Capabilities ?? GuessCapabilities(model.Provider), + IsImplicit = string.Equals(model.Id, "default", StringComparison.OrdinalIgnoreCase) + && config.Models.Profiles.Count == 0 + }; + + private static IEnumerable ValidateProfile(ModelProfile profile, GatewayConfig config) + { + if (string.IsNullOrWhiteSpace(profile.Id)) + yield return "Profile id is required."; + if (string.IsNullOrWhiteSpace(profile.ProviderId)) + yield return "Provider is required."; + if (string.IsNullOrWhiteSpace(profile.ModelId)) + yield return "Model is required."; + if ((profile.ProviderId.Equals("openai-compatible", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("groq", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("together", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("lmstudio", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("azure-openai", StringComparison.OrdinalIgnoreCase)) && + string.IsNullOrWhiteSpace(profile.BaseUrl) && + string.IsNullOrWhiteSpace(config.Llm.Endpoint)) + { + yield return "BaseUrl is required for OpenAI-compatible and Azure OpenAI profiles unless inherited from OpenClaw:Llm:Endpoint."; + } + + if ((profile.ProviderId.Equals("openai", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("openai-compatible", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("groq", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("together", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("azure-openai", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("anthropic", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("claude", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("gemini", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("google", StringComparison.OrdinalIgnoreCase)) && + string.IsNullOrWhiteSpace(profile.ApiKey) && + string.IsNullOrWhiteSpace(config.Llm.ApiKey)) + { + yield return "ApiKey is required for remote provider profiles unless inherited from OpenClaw:Llm:ApiKey."; + } + } + + internal static LlmProviderConfig BuildProviderConfig(GatewayConfig config, ModelProfile profile) + => new() + { + Provider = profile.ProviderId, + Model = profile.ModelId, + ApiKey = profile.ApiKey ?? config.Llm.ApiKey, + Endpoint = profile.BaseUrl ?? config.Llm.Endpoint, + FallbackModels = profile.FallbackModels, + MaxTokens = profile.Capabilities.MaxOutputTokens > 0 ? profile.Capabilities.MaxOutputTokens : config.Llm.MaxTokens, + Temperature = config.Llm.Temperature, + TimeoutSeconds = config.Llm.TimeoutSeconds, + RetryCount = config.Llm.RetryCount, + CircuitBreakerThreshold = config.Llm.CircuitBreakerThreshold, + CircuitBreakerCooldownSeconds = config.Llm.CircuitBreakerCooldownSeconds + }; + + private static string? Normalize(string? value) + => string.IsNullOrWhiteSpace(value) ? null : value.Trim(); + + private static string[] NormalizeDistinct(IEnumerable? values) + => values is null + ? [] + : values.Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); +} diff --git a/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs new file mode 100644 index 0000000..d10df59 --- /dev/null +++ b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs @@ -0,0 +1,265 @@ +using Microsoft.Extensions.AI; +using OpenClaw.Core.Abstractions; +using OpenClaw.Core.Models; + +namespace OpenClaw.Gateway.Models; + +internal sealed class DefaultModelSelectionPolicy : IModelSelectionPolicy +{ + private readonly ConfiguredModelProfileRegistry _registry; + + public DefaultModelSelectionPolicy(ConfiguredModelProfileRegistry registry) + { + _registry = registry; + } + + public ModelSelectionResult Resolve(ModelSelectionRequest request) + { + var requirements = BuildRequirements(request); + var preferredTags = CollectPreferredTags(request.Session); + var fallbackProfileIds = CollectFallbackProfileIds(request.Session); + var explicitProfileId = Normalize(request.ExplicitProfileId) ?? Normalize(request.Session.ModelProfileId); + + var attempted = new List(); + if (!string.IsNullOrWhiteSpace(explicitProfileId)) + { + if (!_registry.TryGetRegistration(explicitProfileId, out var explicitRegistration) || explicitRegistration is null) + throw new ModelSelectionException($"Selected model profile '{explicitProfileId}' is not registered."); + + attempted.Add(new ModelSelectionCandidate + { + Profile = explicitRegistration.Profile, + FallbackModels = explicitRegistration.ProviderConfig.FallbackModels + }); + + if (Satisfies(explicitRegistration.Profile, requirements)) + return BuildResult(explicitProfileId, explicitRegistration.Profile, requirements, preferredTags, attempted, null); + + foreach (var fallbackId in explicitRegistration.Profile.FallbackProfileIds.Concat(fallbackProfileIds)) + { + if (!_registry.TryGetRegistration(fallbackId, out var fallbackRegistration) || fallbackRegistration is null) + continue; + + attempted.Add(new ModelSelectionCandidate + { + Profile = fallbackRegistration.Profile, + FallbackModels = fallbackRegistration.ProviderConfig.FallbackModels + }); + + if (Satisfies(fallbackRegistration.Profile, requirements)) + { + var explanation = + $"Falling back from '{explicitRegistration.Profile.Id}' to '{fallbackRegistration.Profile.Id}' because {DescribeMissing(explicitRegistration.Profile, requirements)}."; + return BuildResult(explicitProfileId, fallbackRegistration.Profile, requirements, preferredTags, attempted, explanation); + } + } + + throw new ModelSelectionException( + $"This route requires {DescribeRequirementSummary(requirements)}, but selected model profile '{explicitRegistration.Profile.Id}' does not support it."); + } + + var candidates = _registry.ListStatuses() + .OrderByDescending(item => Score(item, preferredTags, requirements)) + .ThenByDescending(static item => item.IsDefault) + .ThenBy(static item => item.Id, StringComparer.OrdinalIgnoreCase); + + foreach (var status in candidates) + { + if (!_registry.TryGetRegistration(status.Id, out var registration) || registration is null) + continue; + + if (!Satisfies(registration.Profile, requirements)) + continue; + + attempted.Add(new ModelSelectionCandidate + { + Profile = registration.Profile, + FallbackModels = registration.ProviderConfig.FallbackModels + }); + + return BuildResult(null, registration.Profile, requirements, preferredTags, attempted, null); + } + + throw new ModelSelectionException( + $"No configured model profile satisfies the current request requirements ({DescribeRequirementSummary(requirements)})."); + } + + private static ModelSelectionResult BuildResult( + string? requestedProfileId, + ModelProfile selectedProfile, + ModelSelectionRequirements requirements, + string[] preferredTags, + IReadOnlyList candidates, + string? explanation) + => new() + { + RequestedProfileId = requestedProfileId, + SelectedProfileId = selectedProfile.Id, + ProviderId = selectedProfile.ProviderId, + ModelId = selectedProfile.ModelId, + Requirements = requirements, + PreferredTags = preferredTags, + Candidates = candidates, + Explanation = explanation + }; + + private static int Score(ModelProfileStatus status, IReadOnlyList preferredTags, ModelSelectionRequirements requirements) + { + var score = status.IsDefault ? 100 : 0; + score += preferredTags.Count(tag => status.Tags.Contains(tag, StringComparer.OrdinalIgnoreCase)) * 10; + if (requirements.SupportsTools == true && status.Capabilities.SupportsTools) + score += 25; + if (requirements.SupportsVision == true && status.Capabilities.SupportsVision) + score += 20; + if (requirements.SupportsStructuredOutputs == true && status.Capabilities.SupportsStructuredOutputs) + score += 15; + if (requirements.SupportsStreaming == true && status.Capabilities.SupportsStreaming) + score += 10; + if (requirements.SupportsReasoningEffort == true && status.Capabilities.SupportsReasoningEffort) + score += 5; + if (status.IsAvailable) + score += 5; + return score; + } + + internal static ModelSelectionRequirements BuildRequirements(ModelSelectionRequest request) + { + var combined = Clone(request.Session.ModelRequirements); + + if (request.Streaming) + combined.SupportsStreaming = true; + if (request.Options?.Tools is { Count: > 0 }) + { + combined.SupportsTools = true; + if (request.Options.Tools.Count > 1) + combined.SupportsParallelToolCalls ??= true; + } + + if (request.Options?.ResponseFormat is not null) + { + combined.SupportsJsonSchema = true; + combined.SupportsStructuredOutputs = true; + } + + if (!string.IsNullOrWhiteSpace(request.Session.ReasoningEffort)) + combined.SupportsReasoningEffort = true; + + if (request.Messages.Any(static message => message.Role == ChatRole.System)) + combined.SupportsSystemMessages = true; + + if (request.Messages.SelectMany(static message => message.Contents).OfType().Any(static content => HasMediaPrefix(content.MediaType, "image/"))) + { + combined.SupportsVision = true; + combined.SupportsImageInput = true; + } + + if (request.Messages.SelectMany(static message => message.Contents).OfType().Any(static content => HasMediaPrefix(content.MediaType, "audio/"))) + combined.SupportsAudioInput = true; + + return combined; + } + + private static bool HasMediaPrefix(string? mediaType, string prefix) + => !string.IsNullOrWhiteSpace(mediaType) && + mediaType.StartsWith(prefix, StringComparison.OrdinalIgnoreCase); + + private static ModelSelectionRequirements Clone(ModelSelectionRequirements? source) + => source is null + ? new ModelSelectionRequirements() + : new ModelSelectionRequirements + { + SupportsTools = source.SupportsTools, + SupportsVision = source.SupportsVision, + SupportsJsonSchema = source.SupportsJsonSchema, + SupportsStructuredOutputs = source.SupportsStructuredOutputs, + SupportsStreaming = source.SupportsStreaming, + SupportsParallelToolCalls = source.SupportsParallelToolCalls, + SupportsReasoningEffort = source.SupportsReasoningEffort, + SupportsSystemMessages = source.SupportsSystemMessages, + SupportsImageInput = source.SupportsImageInput, + SupportsAudioInput = source.SupportsAudioInput, + MinContextTokens = source.MinContextTokens, + MinOutputTokens = source.MinOutputTokens + }; + + private static bool Satisfies(ModelProfile profile, ModelSelectionRequirements requirements) + { + var caps = profile.Capabilities; + return Meets(requirements.SupportsTools, caps.SupportsTools) + && Meets(requirements.SupportsVision, caps.SupportsVision) + && Meets(requirements.SupportsJsonSchema, caps.SupportsJsonSchema) + && Meets(requirements.SupportsStructuredOutputs, caps.SupportsStructuredOutputs) + && Meets(requirements.SupportsStreaming, caps.SupportsStreaming) + && Meets(requirements.SupportsParallelToolCalls, caps.SupportsParallelToolCalls) + && Meets(requirements.SupportsReasoningEffort, caps.SupportsReasoningEffort) + && Meets(requirements.SupportsSystemMessages, caps.SupportsSystemMessages) + && Meets(requirements.SupportsImageInput, caps.SupportsImageInput) + && Meets(requirements.SupportsAudioInput, caps.SupportsAudioInput) + && (!requirements.MinContextTokens.HasValue || caps.MaxContextTokens >= requirements.MinContextTokens.Value) + && (!requirements.MinOutputTokens.HasValue || caps.MaxOutputTokens >= requirements.MinOutputTokens.Value); + } + + private static bool Meets(bool? required, bool actual) + => required is not true || actual; + + private static string DescribeMissing(ModelProfile profile, ModelSelectionRequirements requirements) + { + var missing = new List(); + if (requirements.SupportsTools == true && !profile.Capabilities.SupportsTools) + missing.Add("tool calling was required"); + if (requirements.SupportsVision == true && !profile.Capabilities.SupportsVision) + missing.Add("vision was required"); + if (requirements.SupportsJsonSchema == true && !profile.Capabilities.SupportsJsonSchema) + missing.Add("JSON schema output was required"); + if (requirements.SupportsStructuredOutputs == true && !profile.Capabilities.SupportsStructuredOutputs) + missing.Add("structured output was required"); + if (requirements.SupportsStreaming == true && !profile.Capabilities.SupportsStreaming) + missing.Add("streaming was required"); + if (requirements.SupportsReasoningEffort == true && !profile.Capabilities.SupportsReasoningEffort) + missing.Add("reasoning effort was required"); + if (requirements.SupportsImageInput == true && !profile.Capabilities.SupportsImageInput) + missing.Add("image input was required"); + if (requirements.SupportsAudioInput == true && !profile.Capabilities.SupportsAudioInput) + missing.Add("audio input was required"); + return missing.Count == 0 ? "required capabilities were not satisfied" : string.Join(", ", missing); + } + + private static string DescribeRequirementSummary(ModelSelectionRequirements requirements) + { + var items = new List(); + if (requirements.SupportsTools == true) + items.Add("tool calling"); + if (requirements.SupportsVision == true) + items.Add("vision"); + if (requirements.SupportsJsonSchema == true) + items.Add("JSON schema"); + if (requirements.SupportsStructuredOutputs == true) + items.Add("structured outputs"); + if (requirements.SupportsStreaming == true) + items.Add("streaming"); + if (requirements.SupportsReasoningEffort == true) + items.Add("reasoning effort"); + if (requirements.SupportsImageInput == true) + items.Add("image input"); + if (requirements.SupportsAudioInput == true) + items.Add("audio input"); + return items.Count == 0 ? "the requested capabilities" : string.Join("+", items); + } + + private static string[] CollectPreferredTags(Session session) + => session.PreferredModelTags + .Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + + private static string[] CollectFallbackProfileIds(Session session) + => session.FallbackModelProfileIds + .Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + + private static string? Normalize(string? value) + => string.IsNullOrWhiteSpace(value) ? null : value.Trim(); +} diff --git a/src/OpenClaw.Gateway/Models/ModelEvaluationRunner.cs b/src/OpenClaw.Gateway/Models/ModelEvaluationRunner.cs new file mode 100644 index 0000000..84bf720 --- /dev/null +++ b/src/OpenClaw.Gateway/Models/ModelEvaluationRunner.cs @@ -0,0 +1,525 @@ +using System.Text; +using System.Text.Json; +using System.Diagnostics; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using OpenClaw.Core.Models; + +namespace OpenClaw.Gateway.Models; + +internal sealed class ModelEvaluationRunner +{ + private readonly ConfiguredModelProfileRegistry _registry; + private readonly GatewayConfig _config; + private readonly ILogger _logger; + private readonly IReadOnlyList _scenarios; + + public ModelEvaluationRunner( + ConfiguredModelProfileRegistry registry, + GatewayConfig config, + ILogger logger) + { + _registry = registry; + _config = config; + _logger = logger; + _scenarios = + [ + new PlainChatScenario(), + new JsonExtractionScenario(), + new ToolInvocationScenario(), + new MultiTurnContinuityScenario(), + new CompactionRecoveryScenario(), + new StreamingScenario(), + new VisionPromptScenario() + ]; + } + + public IReadOnlyList ListScenarioIds() + => _scenarios.Select(static scenario => scenario.Id).ToArray(); + + public ModelSelectionDoctorResponse BuildDoctor() + { + var statuses = _registry.ListStatuses(); + var warnings = new List(); + var errors = new List(); + if (statuses.Count == 0) + errors.Add("No model profiles are registered."); + if (string.IsNullOrWhiteSpace(_registry.DefaultProfileId)) + errors.Add("No default model profile is configured."); + + foreach (var status in statuses) + { + if (status.ValidationIssues.Length > 0) + warnings.Add($"Profile '{status.Id}' has validation issues: {string.Join("; ", status.ValidationIssues)}"); + } + + return new ModelSelectionDoctorResponse + { + DefaultProfileId = _registry.DefaultProfileId, + Errors = errors, + Warnings = warnings, + Profiles = statuses + }; + } + + public async Task RunAsync(ModelEvaluationRequest request, CancellationToken ct) + { + var startedAt = DateTimeOffset.UtcNow; + var runId = $"meval_{Guid.NewGuid():N}"[..18]; + var scenarioSet = ResolveScenarios(request.ScenarioIds); + var profileIds = ResolveProfiles(request); + var profileReports = new List(profileIds.Length); + + foreach (var profileId in profileIds) + { + ct.ThrowIfCancellationRequested(); + if (!_registry.TryGetRegistration(profileId, out var registration) || registration is null) + { + profileReports.Add(new ModelEvaluationProfileReport + { + ProfileId = profileId, + ProviderId = "unknown", + ModelId = "unknown", + StartedAtUtc = DateTimeOffset.UtcNow, + CompletedAtUtc = DateTimeOffset.UtcNow, + Scenarios = + [ + new ModelEvaluationScenarioResult + { + ScenarioId = "profile_lookup", + Name = "Profile lookup", + Status = "failed", + Error = $"Profile '{profileId}' is not registered." + } + ] + }); + continue; + } + + if (registration.Client is null) + { + profileReports.Add(new ModelEvaluationProfileReport + { + ProfileId = registration.Profile.Id, + ProviderId = registration.Profile.ProviderId, + ModelId = registration.Profile.ModelId, + StartedAtUtc = DateTimeOffset.UtcNow, + CompletedAtUtc = DateTimeOffset.UtcNow, + Scenarios = + [ + new ModelEvaluationScenarioResult + { + ScenarioId = "profile_availability", + Name = "Profile availability", + Status = "failed", + Error = string.Join("; ", registration.ValidationIssues) + } + ] + }); + continue; + } + + var profileStart = DateTimeOffset.UtcNow; + var scenarioResults = new List(scenarioSet.Count); + foreach (var scenario in scenarioSet) + { + ct.ThrowIfCancellationRequested(); + scenarioResults.Add(await scenario.RunAsync(registration.Client, registration.Profile, ct)); + } + + profileReports.Add(new ModelEvaluationProfileReport + { + ProfileId = registration.Profile.Id, + ProviderId = registration.Profile.ProviderId, + ModelId = registration.Profile.ModelId, + StartedAtUtc = profileStart, + CompletedAtUtc = DateTimeOffset.UtcNow, + Scenarios = scenarioResults + }); + } + + var completedAt = DateTimeOffset.UtcNow; + var evaluationDirectory = Path.Combine(Path.GetFullPath(_config.Memory.StoragePath), "admin", "model-evaluations"); + Directory.CreateDirectory(evaluationDirectory); + var jsonPath = Path.Combine(evaluationDirectory, $"{runId}.json"); + var markdownPath = Path.Combine(evaluationDirectory, $"{runId}.md"); + + var report = new ModelEvaluationReport + { + RunId = runId, + StartedAtUtc = startedAt, + CompletedAtUtc = completedAt, + ScenarioIds = scenarioSet.Select(static scenario => scenario.Id).ToArray(), + Profiles = profileReports, + JsonPath = jsonPath, + MarkdownPath = request.IncludeMarkdown ? markdownPath : null + }; + + var reportWithMarkdown = new ModelEvaluationReport + { + RunId = report.RunId, + StartedAtUtc = report.StartedAtUtc, + CompletedAtUtc = report.CompletedAtUtc, + ScenarioIds = report.ScenarioIds, + Profiles = report.Profiles, + JsonPath = report.JsonPath, + MarkdownPath = report.MarkdownPath, + Markdown = request.IncludeMarkdown ? BuildMarkdown(report) : null + }; + + await File.WriteAllTextAsync(jsonPath, JsonSerializer.Serialize(reportWithMarkdown, CoreJsonContext.Default.ModelEvaluationReport), ct); + if (request.IncludeMarkdown && reportWithMarkdown.Markdown is not null) + await File.WriteAllTextAsync(markdownPath, reportWithMarkdown.Markdown, ct); + + return reportWithMarkdown; + } + + private IReadOnlyList ResolveScenarios(IReadOnlyList scenarioIds) + { + if (scenarioIds.Count == 0) + return _scenarios; + + var requested = new HashSet(scenarioIds, StringComparer.OrdinalIgnoreCase); + return _scenarios.Where(scenario => requested.Contains(scenario.Id)).ToArray(); + } + + private string[] ResolveProfiles(ModelEvaluationRequest request) + { + var explicitProfiles = request.ProfileIds + .Concat(string.IsNullOrWhiteSpace(request.ProfileId) ? [] : [request.ProfileId]) + .Where(static item => !string.IsNullOrWhiteSpace(item)) + .Select(static item => item.Trim()) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + + if (explicitProfiles.Length > 0) + return explicitProfiles; + + if (!string.IsNullOrWhiteSpace(_registry.DefaultProfileId)) + return [_registry.DefaultProfileId]; + + return _registry.ListStatuses().Select(static status => status.Id).ToArray(); + } + + private string BuildMarkdown(ModelEvaluationReport report) + { + var sb = new StringBuilder(); + sb.AppendLine("# Model Evaluation Report"); + sb.AppendLine(); + sb.AppendLine($"- Run ID: `{report.RunId}`"); + sb.AppendLine($"- Started: `{report.StartedAtUtc:O}`"); + sb.AppendLine($"- Completed: `{report.CompletedAtUtc:O}`"); + sb.AppendLine($"- Scenarios: {string.Join(", ", report.ScenarioIds)}"); + sb.AppendLine(); + + foreach (var profile in report.Profiles) + { + sb.AppendLine($"## {profile.ProfileId}"); + sb.AppendLine(); + sb.AppendLine($"- Provider/model: `{profile.ProviderId}/{profile.ModelId}`"); + sb.AppendLine($"- Started: `{profile.StartedAtUtc:O}`"); + sb.AppendLine($"- Completed: `{profile.CompletedAtUtc:O}`"); + sb.AppendLine(); + sb.AppendLine("| Scenario | Status | Latency (ms) | Summary |"); + sb.AppendLine("| --- | --- | ---: | --- |"); + foreach (var scenario in profile.Scenarios) + sb.AppendLine($"| {scenario.Name} | {scenario.Status} | {scenario.LatencyMs} | {EscapePipe(scenario.Summary ?? scenario.Error ?? "")} |"); + sb.AppendLine(); + } + + return sb.ToString(); + } + + private static string EscapePipe(string value) + => value.Replace("|", "\\|", StringComparison.Ordinal); + + private interface IModelEvaluationScenario + { + string Id { get; } + string Name { get; } + Task RunAsync(IChatClient client, ModelProfile profile, CancellationToken ct); + } + + private abstract class ModelEvaluationScenarioBase : IModelEvaluationScenario + { + public abstract string Id { get; } + public abstract string Name { get; } + + public async Task RunAsync(IChatClient client, ModelProfile profile, CancellationToken ct) + { + var started = Stopwatch.StartNew(); + try + { + return await ExecuteAsync(client, profile, started, ct); + } + catch (OperationCanceledException) when (ct.IsCancellationRequested) + { + throw; + } + catch (Exception ex) + { + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = "failed", + LatencyMs = started.ElapsedMilliseconds, + Error = ex.Message + }; + } + } + + protected static ModelEvaluationScenarioResult Unsupported(string id, string name, Stopwatch stopwatch, string summary) + => new() + { + ScenarioId = id, + Name = name, + Status = "unsupported", + LatencyMs = stopwatch.ElapsedMilliseconds, + Summary = summary + }; + + protected abstract Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct); + } + + private sealed class PlainChatScenario : ModelEvaluationScenarioBase + { + public override string Id => "plain-chat"; + public override string Name => "Plain chat response"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + var response = await client.GetResponseAsync( + [new ChatMessage(ChatRole.User, "Reply with exactly READY.")], + new ChatOptions { ModelId = profile.ModelId, MaxOutputTokens = 64, Temperature = 0 }, + ct); + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = response.Text?.Contains("READY", StringComparison.OrdinalIgnoreCase) == true ? "passed" : "warning", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = response.Usage?.InputTokenCount ?? 0, + OutputTokens = response.Usage?.OutputTokenCount ?? 0, + Summary = response.Text + }; + } + } + + private sealed class JsonExtractionScenario : ModelEvaluationScenarioBase + { + public override string Id => "json-extraction"; + public override string Name => "Structured JSON extraction"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + if (!profile.Capabilities.SupportsJsonSchema && !profile.Capabilities.SupportsStructuredOutputs) + return Unsupported(Id, Name, stopwatch, $"Profile '{profile.Id}' does not advertise JSON schema or structured outputs."); + + using var schema = JsonDocument.Parse("""{"type":"object","properties":{"animal":{"type":"string"},"count":{"type":"integer"}},"required":["animal","count"],"additionalProperties":false}"""); + var response = await client.GetResponseAsync( + [new ChatMessage(ChatRole.User, "Extract JSON from this sentence: I saw 3 foxes in the garden.")], + new ChatOptions + { + ModelId = profile.ModelId, + MaxOutputTokens = 128, + Temperature = 0, + ResponseFormat = ChatResponseFormat.ForJsonSchema(schema.RootElement.Clone(), "extraction") + }, + ct); + + var malformed = false; + try + { + using var doc = JsonDocument.Parse(response.Text ?? "{}"); + malformed = !(doc.RootElement.TryGetProperty("animal", out _) && doc.RootElement.TryGetProperty("count", out _)); + } + catch + { + malformed = true; + } + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = malformed ? "failed" : "passed", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = response.Usage?.InputTokenCount ?? 0, + OutputTokens = response.Usage?.OutputTokenCount ?? 0, + MalformedJson = malformed, + Summary = response.Text + }; + } + } + + private sealed class ToolInvocationScenario : ModelEvaluationScenarioBase + { + public override string Id => "tool-invocation"; + public override string Name => "Tool selection correctness"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + if (!profile.Capabilities.SupportsTools) + return Unsupported(Id, Name, stopwatch, $"Profile '{profile.Id}' does not advertise tool support."); + + using var schema = JsonDocument.Parse("""{"type":"object","properties":{"value":{"type":"string"}},"required":["value"]}"""); + var tool = AIFunctionFactory.CreateDeclaration("record_observation", "Record the extracted observation.", schema.RootElement.Clone(), returnJsonSchema: null); + var response = await client.GetResponseAsync( + [new ChatMessage(ChatRole.User, "Use the tool to record the value 'gemma4'.")], + new ChatOptions + { + ModelId = profile.ModelId, + Temperature = 0, + MaxOutputTokens = 128, + Tools = [tool] + }, + ct); + + var call = response.Messages + .SelectMany(static message => message.Contents) + .OfType() + .FirstOrDefault(); + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = call is null ? "failed" : "passed", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = response.Usage?.InputTokenCount ?? 0, + OutputTokens = response.Usage?.OutputTokenCount ?? 0, + ToolCalls = call is null ? 0 : 1, + Summary = call is null ? response.Text : $"tool={call.Name}" + }; + } + } + + private sealed class MultiTurnContinuityScenario : ModelEvaluationScenarioBase + { + public override string Id => "multi-turn"; + public override string Name => "Multi-turn continuity"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + var first = await client.GetResponseAsync( + [new ChatMessage(ChatRole.User, "Remember the code word maple-42 and reply with STORED.")], + new ChatOptions { ModelId = profile.ModelId, Temperature = 0, MaxOutputTokens = 64 }, + ct); + var second = await client.GetResponseAsync( + [ + new ChatMessage(ChatRole.User, "Remember the code word maple-42 and reply with STORED."), + new ChatMessage(ChatRole.Assistant, first.Text ?? "STORED"), + new ChatMessage(ChatRole.User, "What code word did I ask you to remember?") + ], + new ChatOptions { ModelId = profile.ModelId, Temperature = 0, MaxOutputTokens = 64 }, + ct); + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = second.Text?.Contains("maple-42", StringComparison.OrdinalIgnoreCase) == true ? "passed" : "warning", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = (first.Usage?.InputTokenCount ?? 0) + (second.Usage?.InputTokenCount ?? 0), + OutputTokens = (first.Usage?.OutputTokenCount ?? 0) + (second.Usage?.OutputTokenCount ?? 0), + Summary = second.Text + }; + } + } + + private sealed class CompactionRecoveryScenario : ModelEvaluationScenarioBase + { + public override string Id => "compaction-recovery"; + public override string Name => "Compaction recovery"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + var response = await client.GetResponseAsync( + [ + new ChatMessage(ChatRole.System, "Conversation summary: The user named the migration branch gemma4-rollout."), + new ChatMessage(ChatRole.User, "What branch name was mentioned in the summary?") + ], + new ChatOptions { ModelId = profile.ModelId, Temperature = 0, MaxOutputTokens = 64 }, + ct); + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = response.Text?.Contains("gemma4-rollout", StringComparison.OrdinalIgnoreCase) == true ? "passed" : "warning", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = response.Usage?.InputTokenCount ?? 0, + OutputTokens = response.Usage?.OutputTokenCount ?? 0, + Summary = response.Text + }; + } + } + + private sealed class StreamingScenario : ModelEvaluationScenarioBase + { + public override string Id => "streaming"; + public override string Name => "Streaming behavior"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + if (!profile.Capabilities.SupportsStreaming) + return Unsupported(Id, Name, stopwatch, $"Profile '{profile.Id}' does not advertise streaming support."); + + var deltas = new List(); + await foreach (var update in client.GetStreamingResponseAsync( + [new ChatMessage(ChatRole.User, "Reply in one short sentence about Gemma.")], + new ChatOptions { ModelId = profile.ModelId, Temperature = 0, MaxOutputTokens = 96 }, + ct)) + { + if (!string.IsNullOrWhiteSpace(update.Text)) + deltas.Add(update.Text); + } + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = deltas.Count > 0 ? "passed" : "warning", + LatencyMs = stopwatch.ElapsedMilliseconds, + Summary = string.Concat(deltas) + }; + } + } + + private sealed class VisionPromptScenario : ModelEvaluationScenarioBase + { + private const string TinyRedPngDataUri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADUlEQVR42mP8z8BQDwAFgwJ/l38DSQAAAABJRU5ErkJggg=="; + + public override string Id => "vision"; + public override string Name => "Vision input behavior"; + + protected override async Task ExecuteAsync(IChatClient client, ModelProfile profile, Stopwatch stopwatch, CancellationToken ct) + { + if (!profile.Capabilities.SupportsVision || !profile.Capabilities.SupportsImageInput) + return Unsupported(Id, Name, stopwatch, $"Profile '{profile.Id}' does not advertise vision/image input support."); + + var response = await client.GetResponseAsync( + [ + new ChatMessage(ChatRole.User, new AIContent[] + { + new TextContent("The image is a single-color square. Answer with one color word."), + new UriContent(new Uri(TinyRedPngDataUri), "image/png") + }) + ], + new ChatOptions { ModelId = profile.ModelId, Temperature = 0, MaxOutputTokens = 32 }, + ct); + + return new ModelEvaluationScenarioResult + { + ScenarioId = Id, + Name = Name, + Status = !string.IsNullOrWhiteSpace(response.Text) ? "passed" : "warning", + LatencyMs = stopwatch.ElapsedMilliseconds, + InputTokens = response.Usage?.InputTokenCount ?? 0, + OutputTokens = response.Usage?.OutputTokenCount ?? 0, + Summary = response.Text + }; + } + } +} diff --git a/src/OpenClaw.Gateway/RuntimeOperationsState.cs b/src/OpenClaw.Gateway/RuntimeOperationsState.cs index 7a4faea..84bc72d 100644 --- a/src/OpenClaw.Gateway/RuntimeOperationsState.cs +++ b/src/OpenClaw.Gateway/RuntimeOperationsState.cs @@ -2,6 +2,7 @@ namespace OpenClaw.Gateway; internal sealed class RuntimeOperationsState { + public OpenClaw.Core.Abstractions.IModelProfileRegistry ModelProfiles { get; init; } = EmptyModelProfileRegistry.Instance; public required ProviderPolicyService ProviderPolicies { get; init; } public required LlmProviderRegistry ProviderRegistry { get; init; } public required GatewayLlmExecutionService LlmExecution { get; init; } @@ -12,4 +13,19 @@ internal sealed class RuntimeOperationsState public required WebhookDeliveryStore WebhookDeliveries { get; init; } public required ActorRateLimitService ActorRateLimits { get; init; } public required SessionMetadataStore SessionMetadata { get; init; } + + private sealed class EmptyModelProfileRegistry : OpenClaw.Core.Abstractions.IModelProfileRegistry + { + public static EmptyModelProfileRegistry Instance { get; } = new(); + + public string? DefaultProfileId => null; + + public IReadOnlyList ListStatuses() => []; + + public bool TryGet(string profileId, out OpenClaw.Core.Models.ModelProfile? profile) + { + profile = null; + return false; + } + } } diff --git a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs index 7d1457f..c6f4cf2 100644 --- a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs +++ b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs @@ -240,6 +240,12 @@ public async Task RunAsync( { throw; } + catch (ModelSelectionException ex) + { + _logger?.LogWarning("[{CorrelationId}] MAF model selection failed: {Message}", turnCtx.CorrelationId, ex.Message); + LogTurnComplete(turnCtx); + return ex.Message; + } catch (Exception ex) { _metrics.IncrementLlmErrors(); @@ -410,6 +416,19 @@ ValueTask WriteStreamEventAsync(AgentStreamEvent evt, CancellationToken token) writer.TryComplete(); throw; } + catch (ModelSelectionException ex) + { + _logger?.LogWarning("[{CorrelationId}] MAF streaming model selection failed: {Message}", turnCtx.CorrelationId, ex.Message); + try + { + await writer.WriteAsync(AgentStreamEvent.ErrorOccurred(ex.Message, "model_selection_failed"), ct); + await writer.WriteAsync(AgentStreamEvent.Complete(), ct); + } + catch (OperationCanceledException) when (ct.IsCancellationRequested) + { + throw; + } + } catch (Exception ex) { _metrics.IncrementLlmErrors(); diff --git a/src/OpenClaw.Tests/ModelProfileSelectionTests.cs b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs new file mode 100644 index 0000000..5bc1404 --- /dev/null +++ b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs @@ -0,0 +1,293 @@ +using System.Text.Json; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging.Abstractions; +using OpenClaw.Core.Models; +using OpenClaw.Core.Validation; +using OpenClaw.Gateway.Bootstrap; +using OpenClaw.Gateway.Extensions; +using OpenClaw.Gateway.Models; +using Xunit; + +namespace OpenClaw.Tests; + +public sealed class ModelProfileSelectionTests +{ + [Fact] + public void Registry_WhenProfilesMissing_CreatesImplicitDefaultProfile() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var config = new GatewayConfig + { + Llm = new LlmProviderConfig + { + Provider = "fake-profile-tests", + Model = "legacy-model" + } + }; + + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var statuses = registry.ListStatuses(); + + var profile = Assert.Single(statuses); + Assert.Equal("default", registry.DefaultProfileId); + Assert.Equal("default", profile.Id); + Assert.True(profile.IsDefault); + Assert.True(profile.IsImplicit); + Assert.Equal("legacy-model", profile.ModelId); + } + + [Fact] + public void SelectionPolicy_ExplicitProfileFallsBackWhenCapabilitiesMissing() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var config = BuildProfileConfig(); + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var policy = new DefaultModelSelectionPolicy(registry); + var session = new Session + { + Id = "s1", + ChannelId = "test", + SenderId = "user", + ModelProfileId = "gemma4-local", + FallbackModelProfileIds = ["frontier-tools"], + ModelRequirements = new ModelSelectionRequirements + { + SupportsTools = true + } + }; + + var selection = policy.Resolve(new OpenClaw.Core.Abstractions.ModelSelectionRequest + { + Session = session, + Messages = [new ChatMessage(ChatRole.User, "Use a tool")], + Options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.CreateDeclaration( + "record_observation", + "Record an observation", + JsonDocument.Parse("""{"type":"object","properties":{"value":{"type":"string"}},"required":["value"]}""").RootElement.Clone(), + returnJsonSchema: null) + ] + }, + Streaming = false + }); + + Assert.Equal("frontier-tools", selection.SelectedProfileId); + Assert.Contains("Falling back from 'gemma4-local'", selection.Explanation, StringComparison.Ordinal); + } + + [Fact] + public void SelectionPolicy_PrefersTaggedProfileWhenRequirementsAreEqual() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var config = BuildProfileConfig(); + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var policy = new DefaultModelSelectionPolicy(registry); + var session = new Session + { + Id = "s2", + ChannelId = "test", + SenderId = "user", + PreferredModelTags = ["private", "local"] + }; + + var selection = policy.Resolve(new OpenClaw.Core.Abstractions.ModelSelectionRequest + { + Session = session, + Messages = [new ChatMessage(ChatRole.User, "Hello")], + Options = new ChatOptions(), + Streaming = false + }); + + Assert.Equal("gemma4-local", selection.SelectedProfileId); + } + + [Fact] + public void ConfigValidator_RejectsUnknownDefaultModelProfile() + { + var config = new GatewayConfig + { + Models = new ModelsConfig + { + DefaultProfile = "missing-profile", + Profiles = + [ + new ModelProfileConfig + { + Id = "gemma4-local", + Provider = "ollama", + Model = "gemma4" + } + ] + } + }; + + var errors = ConfigValidator.Validate(config); + Assert.Contains(errors, error => error.Contains("Models.DefaultProfile", StringComparison.Ordinal)); + } + + [Fact] + public void LoadGatewayConfig_BindsModelProfiles() + { + var values = new Dictionary + { + ["OpenClaw:Llm:Provider"] = "openai", + ["OpenClaw:Llm:Model"] = "gpt-4.1", + ["OpenClaw:Models:DefaultProfile"] = "gemma4-prod", + ["OpenClaw:Models:Profiles:0:Id"] = "gemma4-local", + ["OpenClaw:Models:Profiles:0:Provider"] = "ollama", + ["OpenClaw:Models:Profiles:0:Model"] = "gemma4", + ["OpenClaw:Models:Profiles:0:Tags:0"] = "local", + ["OpenClaw:Models:Profiles:0:Capabilities:SupportsStreaming"] = "true", + ["OpenClaw:Models:Profiles:1:Id"] = "gemma4-prod", + ["OpenClaw:Models:Profiles:1:Provider"] = "openai-compatible", + ["OpenClaw:Models:Profiles:1:Model"] = "gemma-4", + ["OpenClaw:Models:Profiles:1:BaseUrl"] = "https://example.invalid/v1", + ["OpenClaw:Models:Profiles:1:Capabilities:SupportsTools"] = "true" + }; + + var configuration = new ConfigurationBuilder() + .AddInMemoryCollection(values) + .Build(); + + var config = GatewayBootstrapExtensions.LoadGatewayConfig(configuration); + + Assert.Equal("gemma4-prod", config.Models.DefaultProfile); + Assert.Equal(2, config.Models.Profiles.Count); + Assert.Equal("ollama", config.Models.Profiles[0].Provider); + Assert.Equal("https://example.invalid/v1", config.Models.Profiles[1].BaseUrl); + Assert.True(config.Models.Profiles[1].Capabilities.SupportsTools); + } + + [Fact] + public async Task EvaluationRunner_PersistsJsonAndMarkdownReport() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var storagePath = Path.Combine(Path.GetTempPath(), "openclaw-model-evals", Guid.NewGuid().ToString("N")); + var config = BuildProfileConfig(); + config.Memory.StoragePath = storagePath; + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var runner = new ModelEvaluationRunner(registry, config, NullLogger.Instance); + + var report = await runner.RunAsync(new ModelEvaluationRequest + { + ProfileId = "frontier-tools", + ScenarioIds = ["plain-chat", "json-extraction", "tool-invocation"], + IncludeMarkdown = true + }, CancellationToken.None); + + Assert.Equal("frontier-tools", Assert.Single(report.Profiles).ProfileId); + Assert.True(File.Exists(report.JsonPath)); + Assert.True(File.Exists(report.MarkdownPath)); + Assert.Contains("Model Evaluation Report", report.Markdown, StringComparison.Ordinal); + Assert.All(report.Profiles.SelectMany(static profile => profile.Scenarios), result => Assert.NotEqual("failed", result.Status)); + } + + private static GatewayConfig BuildProfileConfig() + => new() + { + Llm = new LlmProviderConfig + { + Provider = "fake-profile-tests", + Model = "legacy-model" + }, + Models = new ModelsConfig + { + DefaultProfile = "gemma4-local", + Profiles = + [ + new ModelProfileConfig + { + Id = "gemma4-local", + Provider = "fake-profile-tests", + Model = "gemma4", + Tags = ["local", "private", "cheap"], + FallbackProfileIds = ["frontier-tools"], + Capabilities = new ModelCapabilities + { + SupportsStreaming = true, + SupportsSystemMessages = true, + SupportsVision = true, + SupportsImageInput = true, + MaxContextTokens = 131072, + MaxOutputTokens = 8192 + } + }, + new ModelProfileConfig + { + Id = "frontier-tools", + Provider = "fake-profile-tests", + Model = "frontier", + Tags = ["tool-reliable", "frontier"], + Capabilities = new ModelCapabilities + { + SupportsTools = true, + SupportsJsonSchema = true, + SupportsStructuredOutputs = true, + SupportsStreaming = true, + SupportsParallelToolCalls = true, + SupportsReasoningEffort = true, + SupportsSystemMessages = true, + SupportsVision = true, + SupportsImageInput = true, + MaxContextTokens = 1_000_000, + MaxOutputTokens = 32768 + } + } + ] + } + }; + + private sealed class EvaluationChatClient : IChatClient + { + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + var lastUser = messages.LastOrDefault(static message => message.Role == ChatRole.User); + var prompt = string.Join("\n", lastUser?.Contents.OfType().Select(static content => content.Text) ?? []); + + if (options?.Tools is { Count: > 0 }) + { + var call = new FunctionCallContent("call_1", "record_observation", new Dictionary { ["value"] = "gemma4" }); + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, [call]))); + } + + if (options?.ResponseFormat is not null) + { + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, """{"animal":"fox","count":3}"""))); + } + + if (prompt.Contains("code word", StringComparison.OrdinalIgnoreCase)) + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "The code word was maple-42."))); + if (prompt.Contains("branch name", StringComparison.OrdinalIgnoreCase)) + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "The branch name was gemma4-rollout."))); + if (prompt.Contains("color", StringComparison.OrdinalIgnoreCase)) + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "red"))); + + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "READY"))); + } + + public async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + yield return new ChatResponseUpdate(ChatRole.Assistant, [new TextContent("Gemma")]); + yield return new ChatResponseUpdate(ChatRole.Assistant, [new TextContent(" streams.")]); + } + + public object? GetService(Type serviceType, object? serviceKey = null) => null; + + public void Dispose() + { + } + } +} From 783374c31bb16c832bcde5ccf05ff257c289b2c0 Mon Sep 17 00:00:00 2001 From: telli Date: Mon, 6 Apr 2026 23:58:55 -0700 Subject: [PATCH 2/2] fix model profile review feedback --- src/OpenClaw.Core/Models/ModelProfiles.cs | 2 +- .../Validation/ConfigValidator.cs | 21 +- src/OpenClaw.Core/Validation/DoctorCheck.cs | 48 ++- .../GatewayLlmExecutionService.cs | 310 +++++++++++++----- .../Models/ConfiguredModelProfileRegistry.cs | 60 +++- .../Models/DefaultModelSelectionPolicy.cs | 93 +++--- .../ModelProfileSelectionTests.cs | 283 +++++++++++++++- 7 files changed, 663 insertions(+), 154 deletions(-) diff --git a/src/OpenClaw.Core/Models/ModelProfiles.cs b/src/OpenClaw.Core/Models/ModelProfiles.cs index fe240b8..742ed69 100644 --- a/src/OpenClaw.Core/Models/ModelProfiles.cs +++ b/src/OpenClaw.Core/Models/ModelProfiles.cs @@ -16,7 +16,7 @@ public sealed class ModelProfileConfig public string[] Tags { get; set; } = []; public string[] FallbackProfileIds { get; set; } = []; public string[] FallbackModels { get; set; } = []; - public ModelCapabilities Capabilities { get; set; } = new(); + public ModelCapabilities? Capabilities { get; set; } } public sealed class ModelCapabilities diff --git a/src/OpenClaw.Core/Validation/ConfigValidator.cs b/src/OpenClaw.Core/Validation/ConfigValidator.cs index 413a4dd..e9a2f1e 100644 --- a/src/OpenClaw.Core/Validation/ConfigValidator.cs +++ b/src/OpenClaw.Core/Validation/ConfigValidator.cs @@ -513,10 +513,9 @@ private static void ValidateRootSet(string field, string[] roots, ICollection errors, bool pluginBackedProvidersPossible) { - if (config.Models.Profiles.Count == 0) - return; - + var hasExplicitProfiles = config.Models.Profiles.Count > 0; var profileIds = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var profile in config.Models.Profiles) { if (string.IsNullOrWhiteSpace(profile.Id)) @@ -535,18 +534,30 @@ private static void ValidateModelProfiles(GatewayConfig config, List err if (string.IsNullOrWhiteSpace(profile.Model)) errors.Add($"Models.Profiles.{profile.Id}.Model must be set."); - if (profile.Capabilities.MaxContextTokens < 0) + if (profile.Capabilities?.MaxContextTokens < 0) errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxContextTokens must be >= 0."); - if (profile.Capabilities.MaxOutputTokens < 0) + if (profile.Capabilities?.MaxOutputTokens < 0) errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxOutputTokens must be >= 0."); } + if (!hasExplicitProfiles) + profileIds.Add("default"); + if (!string.IsNullOrWhiteSpace(config.Models.DefaultProfile) && !profileIds.Contains(config.Models.DefaultProfile)) { errors.Add($"Models.DefaultProfile '{config.Models.DefaultProfile}' does not exist in Models.Profiles."); } + foreach (var profile in config.Models.Profiles) + { + foreach (var fallbackId in profile.FallbackProfileIds.Where(static item => !string.IsNullOrWhiteSpace(item))) + { + if (!profileIds.Contains(fallbackId)) + errors.Add($"Models.Profiles.{profile.Id}.FallbackProfileIds contains unknown profile '{fallbackId}'."); + } + } + foreach (var (routeId, route) in config.Routing.Routes) { if (!string.IsNullOrWhiteSpace(route.ModelProfileId) && !profileIds.Contains(route.ModelProfileId)) diff --git a/src/OpenClaw.Core/Validation/DoctorCheck.cs b/src/OpenClaw.Core/Validation/DoctorCheck.cs index d577b64..9096b4c 100644 --- a/src/OpenClaw.Core/Validation/DoctorCheck.cs +++ b/src/OpenClaw.Core/Validation/DoctorCheck.cs @@ -26,20 +26,9 @@ public static async Task RunAsync(GatewayConfig config, GatewayRuntimeStat allPassed &= Check("LLM max tokens > 0", () => config.Llm.MaxTokens > 0); allPassed &= Check( "Model profile configuration is internally consistent", - () => - { - var profileIds = config.Models.Profiles - .Where(static profile => !string.IsNullOrWhiteSpace(profile.Id)) - .Select(static profile => profile.Id) - .Distinct(StringComparer.OrdinalIgnoreCase) - .Count(); - return config.Models.Profiles.Count == 0 || - (profileIds == config.Models.Profiles.Count && - (string.IsNullOrWhiteSpace(config.Models.DefaultProfile) || - config.Models.Profiles.Any(profile => string.Equals(profile.Id, config.Models.DefaultProfile, StringComparison.OrdinalIgnoreCase)))); - }, + () => HasValidModelProfileConfiguration(config), warnOnly: false, - detail: "Check Models.DefaultProfile, duplicate profile ids, and route profile references."); + detail: "Check Models.DefaultProfile, duplicate profile ids, route profile references, and profile fallback references."); var workspaceRoot = ResolveConfiguredPath(config.Tooling.WorkspaceRoot); if (config.Tooling.WorkspaceOnly) @@ -357,6 +346,39 @@ private static bool HasValidRootSet(string[] roots) return true; } + private static bool HasValidModelProfileConfiguration(GatewayConfig config) + { + var profileIds = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var profile in config.Models.Profiles) + { + if (string.IsNullOrWhiteSpace(profile.Id) || !profileIds.Add(profile.Id)) + return false; + } + + if (profileIds.Count == 0) + profileIds.Add("default"); + + if (!string.IsNullOrWhiteSpace(config.Models.DefaultProfile) && !profileIds.Contains(config.Models.DefaultProfile)) + return false; + + foreach (var profile in config.Models.Profiles) + { + if (profile.FallbackProfileIds.Any(fallbackId => !string.IsNullOrWhiteSpace(fallbackId) && !profileIds.Contains(fallbackId))) + return false; + } + + foreach (var route in config.Routing.Routes.Values) + { + if (!string.IsNullOrWhiteSpace(route.ModelProfileId) && !profileIds.Contains(route.ModelProfileId)) + return false; + + if (route.FallbackModelProfileIds.Any(fallbackId => !string.IsNullOrWhiteSpace(fallbackId) && !profileIds.Contains(fallbackId))) + return false; + } + + return true; + } + private static string ResolveConfiguredPath(string? path) => ConfigPathResolver.Resolve(path); } diff --git a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs index 92470c9..27ce5f9 100644 --- a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs +++ b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using System.Text; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -12,6 +13,12 @@ namespace OpenClaw.Gateway; internal sealed class GatewayLlmExecutionService : ILlmExecutionService { + private sealed class CompatibilityServices + { + public required ConfiguredModelProfileRegistry Registry { get; init; } + public required IModelSelectionPolicy SelectionPolicy { get; init; } + } + private sealed class RouteState { public required CircuitBreaker CircuitBreaker { get; init; } @@ -62,8 +69,27 @@ public GatewayLlmExecutionService( ILogger logger) : this( config, - new ConfiguredModelProfileRegistry(config, NullLogger.Instance), - new DefaultModelSelectionPolicy(new ConfiguredModelProfileRegistry(config, NullLogger.Instance)), + CreateCompatibilityServices(config, registry), + policyService, + eventStore, + runtimeMetrics, + providerUsage, + logger) + { + } + + private GatewayLlmExecutionService( + GatewayConfig config, + CompatibilityServices compatibility, + ProviderPolicyService policyService, + RuntimeEventStore eventStore, + RuntimeMetrics runtimeMetrics, + ProviderUsageTracker providerUsage, + ILogger logger) + : this( + config, + compatibility.Registry, + compatibility.SelectionPolicy, policyService, eventStore, runtimeMetrics, @@ -73,30 +99,39 @@ public GatewayLlmExecutionService( } public CircuitState DefaultCircuitState - => GetRouteState( - _modelProfiles.DefaultProfileId ?? "default", - _config.Llm.Provider, - _config.Llm.Model).CircuitBreaker.State; + { + get + { + if (_modelProfiles.DefaultProfileId is not null && + _modelProfiles.TryGetRegistration(_modelProfiles.DefaultProfileId, out var registration) && + registration is not null) + { + return GetRouteStateSnapshot(registration.Profile.Id, registration.Profile.ProviderId, registration.Profile.ModelId).CircuitBreaker.State; + } + + return GetRouteStateSnapshot("default", _config.Llm.Provider, _config.Llm.Model).CircuitBreaker.State; + } + } public IReadOnlyList SnapshotRoutes() - => _modelProfiles.ListStatuses() - .Select(profile => + => BuildRouteDescriptors() + .Select(route => { - var state = GetRouteState(profile.Id, profile.ProviderId, profile.ModelId); + var state = GetRouteStateSnapshot(route.ProfileId ?? "default", route.ProviderId, route.ModelId); return new ProviderRouteHealthSnapshot { - ProfileId = profile.Id, - ProviderId = profile.ProviderId, - ModelId = profile.ModelId, - IsDefaultRoute = profile.IsDefault, + ProfileId = route.ProfileId, + ProviderId = route.ProviderId, + ModelId = route.ModelId, + IsDefaultRoute = route.IsDefaultRoute, CircuitState = state.CircuitBreaker.State.ToString(), Requests = Interlocked.Read(ref state.Requests), Retries = Interlocked.Read(ref state.Retries), Errors = Interlocked.Read(ref state.Errors), LastError = state.LastError, LastErrorAtUtc = state.LastErrorAtUtc, - Tags = profile.Tags, - ValidationIssues = profile.ValidationIssues + Tags = route.Tags, + ValidationIssues = route.ValidationIssues }; }) .OrderBy(static item => item.ProfileId, StringComparer.OrdinalIgnoreCase) @@ -104,8 +139,12 @@ public IReadOnlyList SnapshotRoutes() public void ResetProvider(string providerId) { - foreach (var key in _routes.Keys.Where(key => key.Contains($":{providerId}:", StringComparison.OrdinalIgnoreCase)).ToArray()) + foreach (var key in _routes.Keys.ToArray()) { + if (!TryParseRouteKey(key, out _, out var routeProviderId, out _) || + !string.Equals(routeProviderId, providerId, StringComparison.OrdinalIgnoreCase)) + continue; + if (_routes.TryRemove(key, out var state)) state.CircuitBreaker.Reset(); } @@ -121,18 +160,11 @@ public async Task GetResponseAsync( { var selection = ResolveSelection(session, messages, options, streaming: false); var legacyPolicy = _policyService.Resolve(session, _config.Llm); - - RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {selection.ProviderId}/{selection.ModelId}", new() - { - ["providerId"] = selection.ProviderId, - ["modelId"] = selection.ModelId, - ["profileId"] = selection.SelectedProfileId ?? "", - ["policyRuleId"] = legacyPolicy.RuleId ?? "" - }); if (!string.IsNullOrWhiteSpace(selection.Explanation)) _logger.LogInformation("{Explanation}", selection.Explanation); Exception? lastError = null; + var routeSelectedRecorded = false; foreach (var candidate in selection.Candidates) { if (!_modelProfiles.TryGetRegistration(candidate.Profile.Id, out var registration) || registration?.Client is null) @@ -146,9 +178,14 @@ public async Task GetResponseAsync( for (var modelIndex = 0; modelIndex < modelsToTry.Length; modelIndex++) { var modelId = modelsToTry[modelIndex]; - var routeState = GetRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, modelId); var chatClient = registration.Client; - var effectiveOptions = CreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate); + if (!TryCreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate, out var effectiveOptions, out var profileLimitError)) + { + lastError = new ModelSelectionException(profileLimitError); + continue; + } + + var routeState = GetOrAddRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, modelId); for (var attempt = 0; attempt <= registration.ProviderConfig.RetryCount; attempt++) { @@ -167,6 +204,18 @@ public async Task GetResponseAsync( try { + if (!routeSelectedRecorded) + { + routeSelectedRecorded = true; + RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {candidate.Profile.ProviderId}/{modelId}", new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = modelId, + ["profileId"] = candidate.Profile.Id, + ["policyRuleId"] = legacyPolicy.RuleId ?? "" + }); + } + RecordEvent(session, turnContext, "llm", "request_started", "info", $"LLM request started for {candidate.Profile.ProviderId}/{modelId}", new() { ["providerId"] = candidate.Profile.ProviderId, @@ -244,56 +293,65 @@ public Task StartStreamingAsync( { var selection = ResolveSelection(session, messages, options, streaming: true); var legacyPolicy = _policyService.Resolve(session, _config.Llm); - var candidate = selection.Candidates.FirstOrDefault() - ?? throw new InvalidOperationException("No model profile candidate is available for streaming."); - if (!_modelProfiles.TryGetRegistration(candidate.Profile.Id, out var registration) || registration?.Client is null) - throw new ModelSelectionException($"Selected model profile '{candidate.Profile.Id}' is not available."); - - var effectiveOptions = CreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate); - var selectedModelId = ResolveRequestedModelId(session, candidate.Profile); - var routeState = GetRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, selectedModelId); - var chatClient = registration.Client; - - Interlocked.Increment(ref routeState.Requests); - _providerUsage.RecordRequest(candidate.Profile.ProviderId, selectedModelId); - RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {candidate.Profile.ProviderId}/{selectedModelId}", new() - { - ["providerId"] = candidate.Profile.ProviderId, - ["modelId"] = selectedModelId, - ["profileId"] = candidate.Profile.Id, - ["policyRuleId"] = legacyPolicy.RuleId ?? "" - }); - RecordEvent(session, turnContext, "llm", "stream_started", "info", $"LLM stream started for {candidate.Profile.ProviderId}/{selectedModelId}", new() + Exception? lastError = null; + foreach (var candidate in selection.Candidates) { - ["providerId"] = candidate.Profile.ProviderId, - ["modelId"] = selectedModelId, - ["profileId"] = candidate.Profile.Id, - ["policyRuleId"] = legacyPolicy.RuleId ?? "" - }); + if (!_modelProfiles.TryGetRegistration(candidate.Profile.Id, out var registration) || registration?.Client is null) + continue; - effectiveOptions.ModelId = selectedModelId; - IAsyncEnumerable updates = StreamWithCircuitAsync( - session, - turnContext, - chatClient, - routeState, - candidate.Profile.ProviderId, - selectedModelId, - messages, - effectiveOptions, - registration.ProviderConfig.TimeoutSeconds, - candidate.Profile.Id, - ct); - - return Task.FromResult(new LlmStreamingExecutionResult - { - ProfileId = candidate.Profile.Id, - ProviderId = candidate.Profile.ProviderId, - ModelId = selectedModelId, - PolicyRuleId = legacyPolicy.RuleId, - SelectionExplanation = selection.Explanation, - Updates = updates - }); + if (!TryCreateEffectiveOptions(options, candidate.Profile, registration.ProviderConfig, legacyPolicy, estimate, out var effectiveOptions, out var profileLimitError)) + { + lastError = new ModelSelectionException(profileLimitError); + continue; + } + + var selectedModelId = ResolveRequestedModelId(session, candidate.Profile); + var routeState = GetOrAddRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, selectedModelId); + var chatClient = registration.Client; + + Interlocked.Increment(ref routeState.Requests); + _providerUsage.RecordRequest(candidate.Profile.ProviderId, selectedModelId); + RecordEvent(session, turnContext, "llm", "route_selected", "info", $"Selected provider route {candidate.Profile.ProviderId}/{selectedModelId}", new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = selectedModelId, + ["profileId"] = candidate.Profile.Id, + ["policyRuleId"] = legacyPolicy.RuleId ?? "" + }); + RecordEvent(session, turnContext, "llm", "stream_started", "info", $"LLM stream started for {candidate.Profile.ProviderId}/{selectedModelId}", new() + { + ["providerId"] = candidate.Profile.ProviderId, + ["modelId"] = selectedModelId, + ["profileId"] = candidate.Profile.Id, + ["policyRuleId"] = legacyPolicy.RuleId ?? "" + }); + + effectiveOptions.ModelId = selectedModelId; + IAsyncEnumerable updates = StreamWithCircuitAsync( + session, + turnContext, + chatClient, + routeState, + candidate.Profile.ProviderId, + selectedModelId, + messages, + effectiveOptions, + registration.ProviderConfig.TimeoutSeconds, + candidate.Profile.Id, + ct); + + return Task.FromResult(new LlmStreamingExecutionResult + { + ProfileId = candidate.Profile.Id, + ProviderId = candidate.Profile.ProviderId, + ModelId = selectedModelId, + PolicyRuleId = legacyPolicy.RuleId, + SelectionExplanation = selection.Explanation, + Updates = updates + }); + } + + throw lastError ?? new InvalidOperationException("No model profile candidate is available for streaming."); } private async IAsyncEnumerable StreamWithCircuitAsync( @@ -374,9 +432,9 @@ private async IAsyncEnumerable StreamWithCircuitAsync( } } - private RouteState GetRouteState(string profileId, string providerId, string modelId) + private RouteState GetOrAddRouteState(string profileId, string providerId, string modelId) => _routes.GetOrAdd( - $"{profileId}:{providerId}:{modelId}", + BuildRouteKey(profileId, providerId, modelId), _ => new RouteState { CircuitBreaker = new CircuitBreaker( @@ -406,12 +464,14 @@ private ModelSelectionResult ResolveSelection( }); } - private ChatOptions CreateEffectiveOptions( + private bool TryCreateEffectiveOptions( ChatOptions source, ModelProfile profile, LlmProviderConfig providerConfig, ResolvedProviderRoute legacyPolicy, - LlmExecutionEstimate estimate) + LlmExecutionEstimate estimate, + out ChatOptions effectiveOptions, + out string profileLimitError) { var maxOutputTokens = source.MaxOutputTokens; if (profile.Capabilities.MaxOutputTokens > 0) @@ -421,8 +481,10 @@ private ChatOptions CreateEffectiveOptions( if (profile.Capabilities.MaxContextTokens > 0 && estimate.EstimatedInputTokens > profile.Capabilities.MaxContextTokens) { - throw new ModelSelectionException( - $"Selected model profile '{profile.Id}' cannot satisfy this request because estimated input tokens ({estimate.EstimatedInputTokens}) exceed MaxContextTokens ({profile.Capabilities.MaxContextTokens})."); + profileLimitError = + $"Selected model profile '{profile.Id}' cannot satisfy this request because estimated input tokens ({estimate.EstimatedInputTokens}) exceed MaxContextTokens ({profile.Capabilities.MaxContextTokens})."; + effectiveOptions = source; + return false; } if (legacyPolicy.MaxInputTokens > 0 && estimate.EstimatedInputTokens > legacyPolicy.MaxInputTokens) @@ -444,7 +506,8 @@ private ChatOptions CreateEffectiveOptions( maxOutputTokens = Math.Min(configuredOutput, (int)remaining); } - return new ChatOptions + profileLimitError = string.Empty; + effectiveOptions = new ChatOptions { ModelId = profile.ModelId, MaxOutputTokens = maxOutputTokens, @@ -452,6 +515,7 @@ private ChatOptions CreateEffectiveOptions( Tools = source.Tools, ResponseFormat = source.ResponseFormat }; + return true; } private string ResolveRequestedModelId(Session session, ModelProfile profile) @@ -491,4 +555,88 @@ private static bool IsTransient(Exception ex) || ex is TimeoutException || ex is TaskCanceledException || ex is CircuitOpenException; + + private RouteState GetRouteStateSnapshot(string profileId, string providerId, string modelId) + => _routes.TryGetValue(BuildRouteKey(profileId, providerId, modelId), out var state) + ? state + : new RouteState + { + CircuitBreaker = new CircuitBreaker( + _config.Llm.CircuitBreakerThreshold, + TimeSpan.FromSeconds(_config.Llm.CircuitBreakerCooldownSeconds), + _logger) + }; + + private IReadOnlyList<(string? ProfileId, string ProviderId, string ModelId, bool IsDefaultRoute, string[] Tags, string[] ValidationIssues)> BuildRouteDescriptors() + { + var descriptors = new Dictionary(StringComparer.Ordinal); + var statuses = _modelProfiles.ListStatuses().ToDictionary(status => status.Id, StringComparer.OrdinalIgnoreCase); + + foreach (var status in statuses.Values) + { + foreach (var modelId in status.FallbackModels.Prepend(status.ModelId).Distinct(StringComparer.OrdinalIgnoreCase)) + { + var key = BuildRouteKey(status.Id, status.ProviderId, modelId); + descriptors[key] = (status.Id, status.ProviderId, modelId, status.IsDefault, status.Tags, status.ValidationIssues); + } + } + + foreach (var key in _routes.Keys) + { + if (!TryParseRouteKey(key, out var profileId, out var providerId, out var modelId) || descriptors.ContainsKey(key)) + continue; + + if (statuses.TryGetValue(profileId, out var status)) + { + descriptors[key] = (profileId, providerId, modelId, status.IsDefault, status.Tags, status.ValidationIssues); + continue; + } + + descriptors[key] = (profileId, providerId, modelId, false, [], []); + } + + return descriptors.Values.ToArray(); + } + + private static string BuildRouteKey(string profileId, string providerId, string modelId) + => string.Join(':', EncodeRouteSegment(profileId), EncodeRouteSegment(providerId), EncodeRouteSegment(modelId)); + + private static bool TryParseRouteKey(string key, out string profileId, out string providerId, out string modelId) + { + profileId = string.Empty; + providerId = string.Empty; + modelId = string.Empty; + + var parts = key.Split(':'); + if (parts.Length != 3) + return false; + + try + { + profileId = DecodeRouteSegment(parts[0]); + providerId = DecodeRouteSegment(parts[1]); + modelId = DecodeRouteSegment(parts[2]); + return true; + } + catch + { + return false; + } + } + + private static string EncodeRouteSegment(string value) + => Convert.ToBase64String(Encoding.UTF8.GetBytes(value ?? string.Empty)); + + private static string DecodeRouteSegment(string value) + => Encoding.UTF8.GetString(Convert.FromBase64String(value)); + + private static CompatibilityServices CreateCompatibilityServices(GatewayConfig config, LlmProviderRegistry registry) + { + var modelProfiles = new ConfiguredModelProfileRegistry(config, NullLogger.Instance, registry); + return new CompatibilityServices + { + Registry = modelProfiles, + SelectionPolicy = new DefaultModelSelectionPolicy(modelProfiles) + }; + } } diff --git a/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs index 10715c1..9799723 100644 --- a/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs +++ b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.Logging; using OpenClaw.Core.Abstractions; using OpenClaw.Core.Models; +using OpenClaw.Core.Security; using OpenClaw.Gateway.Extensions; namespace OpenClaw.Gateway.Models; @@ -20,10 +21,20 @@ internal sealed class Registration private readonly ConcurrentDictionary _registrations = new(StringComparer.OrdinalIgnoreCase); private readonly ILogger _logger; + private readonly LlmProviderRegistry? _providerRegistry; public ConfiguredModelProfileRegistry(GatewayConfig config, ILogger logger) + : this(config, logger, null) + { + } + + public ConfiguredModelProfileRegistry( + GatewayConfig config, + ILogger logger, + LlmProviderRegistry? providerRegistry) { _logger = logger; + _providerRegistry = providerRegistry; DefaultProfileId = BuildRegistrations(config); } @@ -80,14 +91,17 @@ private string BuildRegistrations(GatewayConfig config) IChatClient? client = null; if (issues.Length == 0) { - try - { - client = LlmClientFactory.CreateChatClient(providerConfig); - } - catch (Exception ex) + if (!TryResolveRegisteredClient(profile, out client)) { - issues = [.. issues, ex.Message]; - _logger.LogWarning(ex, "Failed to initialize model profile {ProfileId}", profile.Id); + try + { + client = LlmClientFactory.CreateChatClient(providerConfig); + } + catch (Exception ex) + { + issues = [.. issues, ex.Message]; + _logger.LogWarning(ex, "Failed to initialize model profile {ProfileId}", profile.Id); + } } } @@ -164,12 +178,12 @@ private static ModelProfile ToProfile(GatewayConfig config, ModelProfileConfig m Id = Normalize(model.Id) ?? "default", ProviderId = Normalize(model.Provider) ?? config.Llm.Provider, ModelId = Normalize(model.Model) ?? config.Llm.Model, - BaseUrl = Normalize(model.BaseUrl), - ApiKey = Normalize(model.ApiKey), + BaseUrl = ResolveSecretValue(model.BaseUrl), + ApiKey = ResolveSecretValue(model.ApiKey), Tags = NormalizeDistinct(model.Tags), FallbackProfileIds = NormalizeDistinct(model.FallbackProfileIds), FallbackModels = NormalizeDistinct(model.FallbackModels), - Capabilities = model.Capabilities ?? GuessCapabilities(model.Provider), + Capabilities = model.Capabilities ?? GuessCapabilities(Normalize(model.Provider) ?? config.Llm.Provider), IsImplicit = string.Equals(model.Id, "default", StringComparison.OrdinalIgnoreCase) && config.Models.Profiles.Count == 0 }; @@ -228,6 +242,15 @@ internal static LlmProviderConfig BuildProviderConfig(GatewayConfig config, Mode private static string? Normalize(string? value) => string.IsNullOrWhiteSpace(value) ? null : value.Trim(); + private static string? ResolveSecretValue(string? value) + { + if (string.IsNullOrWhiteSpace(value)) + return null; + + var resolved = SecretResolver.Resolve(value); + return string.IsNullOrWhiteSpace(resolved) ? null : resolved.Trim(); + } + private static string[] NormalizeDistinct(IEnumerable? values) => values is null ? [] @@ -235,4 +258,21 @@ private static string[] NormalizeDistinct(IEnumerable? values) .Select(static item => item.Trim()) .Distinct(StringComparer.OrdinalIgnoreCase) .ToArray(); + + private bool TryResolveRegisteredClient(ModelProfile profile, out IChatClient? client) + { + client = null; + if (_providerRegistry is null || !_providerRegistry.TryGet(profile.ProviderId, out var registration) || registration?.Client is null) + return false; + + if (registration.Models.Length > 0 && + !registration.Models.Contains(profile.ModelId, StringComparer.OrdinalIgnoreCase) && + !profile.FallbackModels.Any(model => registration.Models.Contains(model, StringComparer.OrdinalIgnoreCase))) + { + return false; + } + + client = registration.Client; + return true; + } } diff --git a/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs index d10df59..8980e58 100644 --- a/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs +++ b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs @@ -20,65 +20,50 @@ public ModelSelectionResult Resolve(ModelSelectionRequest request) var fallbackProfileIds = CollectFallbackProfileIds(request.Session); var explicitProfileId = Normalize(request.ExplicitProfileId) ?? Normalize(request.Session.ModelProfileId); - var attempted = new List(); if (!string.IsNullOrWhiteSpace(explicitProfileId)) { if (!_registry.TryGetRegistration(explicitProfileId, out var explicitRegistration) || explicitRegistration is null) throw new ModelSelectionException($"Selected model profile '{explicitProfileId}' is not registered."); - attempted.Add(new ModelSelectionCandidate + var explicitCandidates = explicitRegistration.Profile.FallbackProfileIds + .Concat(fallbackProfileIds) + .Distinct(StringComparer.OrdinalIgnoreCase) + .Select(fallbackId => _registry.TryGetRegistration(fallbackId, out var fallbackRegistration) ? fallbackRegistration : null) + .Where(static registration => registration is not null) + .Cast() + .Where(registration => IsSelectable(registration) && Satisfies(registration.Profile, requirements)) + .Select(ToCandidate) + .ToList(); + var explicitMissing = DescribeUnselectable(explicitRegistration, requirements); + if (IsSelectable(explicitRegistration) && Satisfies(explicitRegistration.Profile, requirements)) { - Profile = explicitRegistration.Profile, - FallbackModels = explicitRegistration.ProviderConfig.FallbackModels - }); - - if (Satisfies(explicitRegistration.Profile, requirements)) - return BuildResult(explicitProfileId, explicitRegistration.Profile, requirements, preferredTags, attempted, null); - - foreach (var fallbackId in explicitRegistration.Profile.FallbackProfileIds.Concat(fallbackProfileIds)) + explicitCandidates.Insert(0, ToCandidate(explicitRegistration)); + return BuildResult(explicitProfileId, explicitRegistration.Profile, requirements, preferredTags, explicitCandidates, null); + } + if (explicitCandidates.Count > 0) { - if (!_registry.TryGetRegistration(fallbackId, out var fallbackRegistration) || fallbackRegistration is null) - continue; - - attempted.Add(new ModelSelectionCandidate - { - Profile = fallbackRegistration.Profile, - FallbackModels = fallbackRegistration.ProviderConfig.FallbackModels - }); - - if (Satisfies(fallbackRegistration.Profile, requirements)) - { - var explanation = - $"Falling back from '{explicitRegistration.Profile.Id}' to '{fallbackRegistration.Profile.Id}' because {DescribeMissing(explicitRegistration.Profile, requirements)}."; - return BuildResult(explicitProfileId, fallbackRegistration.Profile, requirements, preferredTags, attempted, explanation); - } + var explanation = + $"Falling back from '{explicitRegistration.Profile.Id}' to '{explicitCandidates[0].Profile.Id}' because {explicitMissing}."; + return BuildResult(explicitProfileId, explicitCandidates[0].Profile, requirements, preferredTags, explicitCandidates, explanation); } throw new ModelSelectionException( - $"This route requires {DescribeRequirementSummary(requirements)}, but selected model profile '{explicitRegistration.Profile.Id}' does not support it."); + $"This route requires {DescribeRequirementSummary(requirements)}, but selected model profile '{explicitRegistration.Profile.Id}' cannot satisfy it because {explicitMissing}."); } var candidates = _registry.ListStatuses() .OrderByDescending(item => Score(item, preferredTags, requirements)) .ThenByDescending(static item => item.IsDefault) - .ThenBy(static item => item.Id, StringComparer.OrdinalIgnoreCase); - - foreach (var status in candidates) - { - if (!_registry.TryGetRegistration(status.Id, out var registration) || registration is null) - continue; - - if (!Satisfies(registration.Profile, requirements)) - continue; - - attempted.Add(new ModelSelectionCandidate - { - Profile = registration.Profile, - FallbackModels = registration.ProviderConfig.FallbackModels - }); + .ThenBy(static item => item.Id, StringComparer.OrdinalIgnoreCase) + .Select(status => _registry.TryGetRegistration(status.Id, out var registration) ? registration : null) + .Where(static registration => registration is not null) + .Cast() + .Where(registration => IsSelectable(registration) && Satisfies(registration.Profile, requirements)) + .Select(ToCandidate) + .ToArray(); - return BuildResult(null, registration.Profile, requirements, preferredTags, attempted, null); - } + if (candidates.Length > 0) + return BuildResult(null, candidates[0].Profile, requirements, preferredTags, candidates, null); throw new ModelSelectionException( $"No configured model profile satisfies the current request requirements ({DescribeRequirementSummary(requirements)})."); @@ -122,6 +107,16 @@ private static int Score(ModelProfileStatus status, IReadOnlyList prefer return score; } + private static bool IsSelectable(ConfiguredModelProfileRegistry.Registration registration) + => registration.Client is not null && registration.ValidationIssues.Length == 0; + + private static ModelSelectionCandidate ToCandidate(ConfiguredModelProfileRegistry.Registration registration) + => new() + { + Profile = registration.Profile, + FallbackModels = registration.ProviderConfig.FallbackModels + }; + internal static ModelSelectionRequirements BuildRequirements(ModelSelectionRequest request) { var combined = Clone(request.Session.ModelRequirements); @@ -224,6 +219,18 @@ private static string DescribeMissing(ModelProfile profile, ModelSelectionRequir return missing.Count == 0 ? "required capabilities were not satisfied" : string.Join(", ", missing); } + private static string DescribeUnselectable( + ConfiguredModelProfileRegistry.Registration registration, + ModelSelectionRequirements requirements) + { + if (!IsSelectable(registration)) + return registration.ValidationIssues.Length > 0 + ? string.Join("; ", registration.ValidationIssues) + : "the profile is not available"; + + return DescribeMissing(registration.Profile, requirements); + } + private static string DescribeRequirementSummary(ModelSelectionRequirements requirements) { var items = new List(); diff --git a/src/OpenClaw.Tests/ModelProfileSelectionTests.cs b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs index 5bc1404..ce71fdf 100644 --- a/src/OpenClaw.Tests/ModelProfileSelectionTests.cs +++ b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs @@ -2,8 +2,11 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging.Abstractions; +using OpenClaw.Agent; using OpenClaw.Core.Models; +using OpenClaw.Core.Observability; using OpenClaw.Core.Validation; +using OpenClaw.Gateway; using OpenClaw.Gateway.Bootstrap; using OpenClaw.Gateway.Extensions; using OpenClaw.Gateway.Models; @@ -135,6 +138,67 @@ public void ConfigValidator_RejectsUnknownDefaultModelProfile() Assert.Contains(errors, error => error.Contains("Models.DefaultProfile", StringComparison.Ordinal)); } + [Fact] + public void ConfigValidator_RejectsUnknownFallbackProfileIds() + { + var config = new GatewayConfig + { + Models = new ModelsConfig + { + Profiles = + [ + new ModelProfileConfig + { + Id = "gemma4-local", + Provider = "ollama", + Model = "gemma4", + FallbackProfileIds = ["missing-profile"] + } + ] + }, + Routing = new RoutingConfig + { + Routes = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["telegram:coder"] = new() + { + ModelProfileId = "missing-profile" + } + } + } + }; + + var errors = ConfigValidator.Validate(config); + Assert.Contains(errors, error => error.Contains("Models.Profiles.gemma4-local.FallbackProfileIds", StringComparison.Ordinal)); + Assert.Contains(errors, error => error.Contains("Routing.Routes.telegram:coder.ModelProfileId", StringComparison.Ordinal)); + } + + [Fact] + public void ConfigValidator_RejectsExplicitRouteProfileWhenUsingImplicitDefaultOnly() + { + var config = new GatewayConfig + { + Llm = new LlmProviderConfig + { + Provider = "openai", + Model = "gpt-4.1" + }, + Routing = new RoutingConfig + { + Routes = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + ["telegram:coder"] = new() + { + ModelProfileId = "gemma4-local" + } + } + } + }; + + var errors = ConfigValidator.Validate(config); + Assert.Contains(errors, error => error.Contains("Routing.Routes.telegram:coder.ModelProfileId", StringComparison.Ordinal)); + } + [Fact] public void LoadGatewayConfig_BindsModelProfiles() { @@ -165,7 +229,224 @@ public void LoadGatewayConfig_BindsModelProfiles() Assert.Equal(2, config.Models.Profiles.Count); Assert.Equal("ollama", config.Models.Profiles[0].Provider); Assert.Equal("https://example.invalid/v1", config.Models.Profiles[1].BaseUrl); - Assert.True(config.Models.Profiles[1].Capabilities.SupportsTools); + Assert.NotNull(config.Models.Profiles[1].Capabilities); + Assert.True(config.Models.Profiles[1].Capabilities!.SupportsTools); + } + + [Fact] + public void Registry_WhenCapabilitiesOmitted_UsesProviderGuessAndResolvesSecrets() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("openai-compatible", new EvaluationChatClient()); + + Environment.SetEnvironmentVariable("MODEL_PROFILE_ENDPOINT", "https://example.invalid/v1"); + Environment.SetEnvironmentVariable("MODEL_PROFILE_KEY", "secret-token"); + try + { + var config = new GatewayConfig + { + Llm = new LlmProviderConfig + { + Provider = "openai", + Model = "gpt-4.1", + ApiKey = "fallback-key" + }, + Models = new ModelsConfig + { + Profiles = + [ + new ModelProfileConfig + { + Id = "gemma4-prod", + Provider = "openai-compatible", + Model = "gemma-4", + BaseUrl = "env:MODEL_PROFILE_ENDPOINT", + ApiKey = "env:MODEL_PROFILE_KEY" + } + ] + } + }; + + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + Assert.True(registry.TryGet("gemma4-prod", out var profile)); + Assert.NotNull(profile); + Assert.Equal("https://example.invalid/v1", profile!.BaseUrl); + Assert.Equal("secret-token", profile.ApiKey); + Assert.True(profile.Capabilities.SupportsTools); + Assert.True(profile.Capabilities.SupportsStructuredOutputs); + } + finally + { + Environment.SetEnvironmentVariable("MODEL_PROFILE_ENDPOINT", null); + Environment.SetEnvironmentVariable("MODEL_PROFILE_KEY", null); + } + } + + [Fact] + public void SelectionPolicy_SkipsUnavailableExplicitProfileAndFallsBack() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var config = new GatewayConfig + { + Llm = new LlmProviderConfig + { + Provider = "fake-profile-tests", + Model = "legacy-model" + }, + Models = new ModelsConfig + { + Profiles = + [ + new ModelProfileConfig + { + Id = "broken-remote", + Provider = "openai-compatible", + Model = "gemma-4", + FallbackProfileIds = ["frontier-tools"], + Capabilities = new ModelCapabilities + { + SupportsTools = true + } + }, + new ModelProfileConfig + { + Id = "frontier-tools", + Provider = "fake-profile-tests", + Model = "frontier", + Capabilities = new ModelCapabilities + { + SupportsTools = true, + SupportsStreaming = true, + SupportsSystemMessages = true + } + } + ] + } + }; + + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var policy = new DefaultModelSelectionPolicy(registry); + var selection = policy.Resolve(new OpenClaw.Core.Abstractions.ModelSelectionRequest + { + ExplicitProfileId = "broken-remote", + Session = new Session + { + Id = "s3", + ChannelId = "test", + SenderId = "user" + }, + Messages = [new ChatMessage(ChatRole.User, "Need tools")], + Options = new ChatOptions + { + Tools = + [ + AIFunctionFactory.CreateDeclaration( + "record_observation", + "Record an observation", + JsonDocument.Parse("""{"type":"object","properties":{"value":{"type":"string"}},"required":["value"]}""").RootElement.Clone(), + returnJsonSchema: null) + ] + } + }); + + Assert.Equal("frontier-tools", selection.SelectedProfileId); + Assert.Contains("broken-remote", selection.Explanation, StringComparison.Ordinal); + } + + [Fact] + public async Task GatewayExecution_FallsBackWhenSelectedProfileContextTooSmall() + { + LlmClientFactory.ResetDynamicProviders(); + LlmClientFactory.RegisterProvider("fake-profile-tests", new EvaluationChatClient()); + + var storagePath = Path.Combine(Path.GetTempPath(), "openclaw-model-selection", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(storagePath); + var config = BuildProfileConfig(); + var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance); + var policy = new DefaultModelSelectionPolicy(registry); + var service = new GatewayLlmExecutionService( + config, + registry, + policy, + new ProviderPolicyService(storagePath, NullLogger.Instance), + new RuntimeEventStore(storagePath, NullLogger.Instance), + new RuntimeMetrics(), + new ProviderUsageTracker(), + NullLogger.Instance); + + var session = new Session + { + Id = "s4", + ChannelId = "test", + SenderId = "user", + ModelProfileId = "gemma4-local" + }; + + var result = await service.GetResponseAsync( + session, + [new ChatMessage(ChatRole.User, "hello")], + new ChatOptions(), + new TurnContext { SessionId = session.Id, ChannelId = session.ChannelId }, + new LlmExecutionEstimate + { + EstimatedInputTokens = 200_000, + EstimatedInputTokensByComponent = new InputTokenComponentEstimate() + }, + CancellationToken.None); + + Assert.Equal("frontier-tools", result.ProfileId); + Assert.Equal("frontier", result.ModelId); + } + + [Fact] + public async Task GatewayExecutionService_CompatibilityConstructor_UsesInjectedProviderRegistry() + { + var storagePath = Path.Combine(Path.GetTempPath(), "openclaw-model-compat", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(storagePath); + var config = new GatewayConfig + { + Llm = new LlmProviderConfig + { + Provider = "fake-injected-provider", + Model = "legacy-model" + } + }; + + var providerRegistry = new LlmProviderRegistry(); + providerRegistry.RegisterDefault(config.Llm, new EvaluationChatClient()); + var service = new GatewayLlmExecutionService( + config, + providerRegistry, + new ProviderPolicyService(storagePath, NullLogger.Instance), + new RuntimeEventStore(storagePath, NullLogger.Instance), + new RuntimeMetrics(), + new ProviderUsageTracker(), + NullLogger.Instance); + + var session = new Session + { + Id = "s5", + ChannelId = "test", + SenderId = "user" + }; + + var result = await service.GetResponseAsync( + session, + [new ChatMessage(ChatRole.User, "hello")], + new ChatOptions(), + new TurnContext { SessionId = session.Id, ChannelId = session.ChannelId }, + new LlmExecutionEstimate + { + EstimatedInputTokens = 16, + EstimatedInputTokensByComponent = new InputTokenComponentEstimate() + }, + CancellationToken.None); + + Assert.Equal("default", result.ProfileId); + Assert.Equal("fake-injected-provider", result.ProviderId); + Assert.Equal("legacy-model", result.ModelId); } [Fact]