diff --git a/README.md b/README.md index a5ac41b..1ddca7a 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,18 @@ OpenClaw now supports **provider-agnostic named model profiles**. This keeps mod See [Model Profiles and Gemma](docs/MODEL_PROFILES.md) for the full configuration and evaluation guide. +### Prompt Caching + +OpenClaw can now attach **provider-aware prompt caching hints** through the existing model-profile and provider seams rather than introducing a cache-specific runtime path. + +- Configure prompt caching globally under `OpenClaw:Llm:PromptCaching` or per named model profile +- Supported cache dialects are normalized as `openai`, `anthropic`, `gemini`, or `none` +- `openai-compatible` and dynamic providers must opt into a dialect explicitly before cache hints are sent +- Cache usage is normalized into `cacheRead` / `cacheWrite` counters and exposed through diagnostics, session status, and provider usage summaries +- Keep-warm is intentionally selective in v1 and only applies to providers with explicit cache TTL/resource semantics + +See [Prompt Caching](docs/PROMPT_CACHING.md) for configuration, provider behavior, and diagnostics details. + ### Review-First Learning - The runtime can observe completed sessions and create **pending learning proposals** instead of auto-mutating behavior diff --git a/docs/PROMPT_CACHING.md b/docs/PROMPT_CACHING.md new file mode 100644 index 0000000..9ea5e86 --- /dev/null +++ b/docs/PROMPT_CACHING.md @@ -0,0 +1,190 @@ +# Prompt Caching + +OpenClaw.NET supports prompt caching as a provider-aware optimization layered on top of the existing provider and model-profile architecture. The runtime still talks to providers through the same `ILlmExecutionService` and model-selection flow. Prompt caching only changes request shaping, normalized usage accounting, and optional keep-warm behavior. + +## Why it exists + +Prompt caching helps when a large prefix of the request stays stable across turns: + +- base system prompt +- tool declarations +- skill prompt content +- stable workspace prompt files + +When the upstream provider supports prompt caching, OpenClaw can attach cache hints and normalize returned cache usage as: + +- `cacheRead` +- `cacheWrite` + +This improves cost and latency visibility for long-running sessions without introducing a provider-specific runtime fork. + +## Configuration + +Prompt caching can be configured globally: + +```json +{ + "OpenClaw": { + "Llm": { + "Provider": "openai", + "Model": "gpt-4.1", + "PromptCaching": { + "Enabled": true, + "Retention": "auto", + "Dialect": "openai", + "KeepWarmEnabled": false, + "KeepWarmIntervalMinutes": 55, + "TraceEnabled": false, + "TraceFilePath": "./memory/logs/cache-trace.jsonl" + } + } + } +} +``` + +Or per model profile: + +```json +{ + "OpenClaw": { + "Models": { + "DefaultProfile": "gemma4-prod", + "Profiles": [ + { + "Id": "gemma4-prod", + "Provider": "openai-compatible", + "Model": "gemma-4", + "BaseUrl": "https://gateway.example.com/v1", + "ApiKey": "env:MODEL_PROVIDER_KEY", + "PromptCaching": { + "Enabled": true, + "Dialect": "openai", + "Retention": "auto" + } + }, + { + "Id": "claude-research", + "Provider": "anthropic", + "Model": "claude-sonnet-4.5", + "PromptCaching": { + "Enabled": true, + "Dialect": "anthropic", + "Retention": "long", + "KeepWarmEnabled": true, + "KeepWarmIntervalMinutes": 55 + } + } + ] + } + } +} +``` + +Profile settings override the global `OpenClaw:Llm:PromptCaching` values field-by-field. 
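+As a rough illustration of the field-by-field override, the effective settings behave like the following sketch (a hypothetical `Merge` helper, not the actual runtime code; how the non-nullable `KeepWarmIntervalMinutes` merges is an assumption here):
+
+```csharp
+// Sketch: unset (null) profile fields fall back to the global value.
+static PromptCachingConfig Merge(PromptCachingConfig global, PromptCachingConfig? profile)
+{
+    if (profile is null)
+        return global;
+
+    return new PromptCachingConfig
+    {
+        Enabled = profile.Enabled ?? global.Enabled,
+        Retention = profile.Retention ?? global.Retention,
+        Dialect = profile.Dialect ?? global.Dialect,
+        KeepWarmEnabled = profile.KeepWarmEnabled ?? global.KeepWarmEnabled,
+        TraceEnabled = profile.TraceEnabled ?? global.TraceEnabled,
+        TraceFilePath = profile.TraceFilePath ?? global.TraceFilePath,
+        // Non-nullable field: assume the profile value applies as-is.
+        KeepWarmIntervalMinutes = profile.KeepWarmIntervalMinutes
+    };
+}
+```
+
+So in the `claude-research` profile above, the dialect, retention, and keep-warm fields all come from the profile, while anything the profile leaves unset falls back to `OpenClaw:Llm:PromptCaching`.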
+
+## Supported fields
+
+- `Enabled`: turns prompt caching behavior on for that scope
+- `Retention`: `none`, `short`, `long`, or `auto`
+- `Dialect`: `auto`, `openai`, `anthropic`, `gemini`, or `none`
+- `KeepWarmEnabled`: enables selective keep-warm for eligible providers
+- `KeepWarmIntervalMinutes`: minimum interval between keep-warm refreshes, in minutes
+- `TraceEnabled`: emits cache-trace JSONL entries
+- `TraceFilePath`: optional trace output path
+
+## Provider behavior
+
+### OpenAI and Azure OpenAI
+
+- Uses deterministic cache-key hints through request additional properties
+- Normalizes provider-reported cached prompt tokens into `cacheRead`
+- Does not fabricate `cacheWrite` when the provider does not report it
+
+### OpenAI-compatible
+
+- Prompt caching is only enabled when `Dialect` is explicitly set to `openai`
+- If prompt caching is enabled but the dialect stays `auto`, config validation rejects the configuration and doctor mode warns before runtime
+
+### Anthropic and Anthropic Vertex
+
+- Uses Anthropic-style cache hints
+- Maps provider cache-read and cache-creation/write usage when reported
+- Eligible for keep-warm when explicitly enabled
+
+### Amazon Bedrock
+
+- Bedrock is available as a provider id for cache-policy routing and validation
+- Anthropic-style cache behavior is only meaningful for Anthropic Claude models behind a Bedrock-compatible endpoint or adapter
+- Non-Anthropic Bedrock models are treated as no-cache for retention and keep-warm purposes
+
+### Gemini
+
+- Uses Gemini cache dialect hints and normalized cache accounting
+- Eligible for keep-warm when explicitly enabled
+
+### Ollama
+
+- No prompt caching behavior in v1
+- Model capabilities reflect that prompt caching is unsupported
+
+### Dynamic / plugin providers
+
+- Prompt cache hints are passed through `ChatOptions.AdditionalProperties`
+- The provider must opt into a cache dialect explicitly
+- If the provider returns usage counters with cache fields, OpenClaw normalizes them into `cacheRead` / `cacheWrite`
+
+## Diagnostics
+
+Prompt cache usage is surfaced in:
+
+- `/metrics/providers`
+- `/doctor/text`
+- session status summaries
+- `/status` and `/usage` command output
+
+If live session cache totals are missing, OpenClaw falls back to the most recent nonzero cache counters recorded in provider usage history for that session.
+
+## Cache tracing
+
+Cache tracing can be enabled through configuration:
+
+```json
+{
+  "OpenClaw": {
+    "Diagnostics": {
+      "CacheTrace": {
+        "Enabled": true,
+        "FilePath": "./memory/logs/cache-trace.jsonl",
+        "IncludeMessages": true,
+        "IncludePrompt": true,
+        "IncludeSystem": true
+      }
+    }
+  }
+}
+```
+
+Or with environment variables:
+
+- `OPENCLAW_CACHE_TRACE=1`
+- `OPENCLAW_CACHE_TRACE_FILE=/path/to/cache-trace.jsonl`
+- `OPENCLAW_CACHE_TRACE_PROMPT=0|1`
+- `OPENCLAW_CACHE_TRACE_SYSTEM=0|1`
+
+Trace output is JSONL; each entry includes:
+
+- the selected profile, provider, and model
+- dialect and retention
+- a stable prompt fingerprint
+- normalized cache usage counters
+
+## Keep-warm
+
+Keep-warm is intentionally conservative in v1:
+
+- It runs in a dedicated background service
+- It only warms active sessions with recent stable prompt fingerprints
+- It only warms profiles that explicitly set `KeepWarmEnabled=true`
+- It only applies to providers with explicit TTL or cache-resource semantics
+
+Providers that are not explicitly eligible are skipped without failing normal requests, as the eligibility sketch below illustrates.
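+
+A minimal sketch of that eligibility decision, assuming the provider ids and dialects described above (hypothetical helper, not the actual keep-warm service code):
+
+```csharp
+// Sketch: keep-warm needs explicit opt-in plus a provider with explicit cache TTL/resource semantics.
+static bool IsKeepWarmEligible(string provider, PromptCachingConfig caching)
+{
+    if (caching.Enabled != true || caching.KeepWarmEnabled != true)
+        return false;
+
+    var p = provider.Trim().ToLowerInvariant();
+    var dialect = (caching.Dialect ?? "auto").Trim().ToLowerInvariant();
+
+    // Anthropic-family providers carry explicit cache TTL semantics.
+    if (p is "anthropic" or "claude" or "anthropic-vertex")
+        return true;
+
+    // Bedrock qualifies only when routed at Anthropic-style caching.
+    if (p is "amazon-bedrock")
+        return dialect is "anthropic" or "auto";
+
+    // Gemini exposes explicit cache resources.
+    if (p is "gemini" or "google")
+        return dialect is "gemini" or "auto";
+
+    return false;
+}
+```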
diff --git a/src/OpenClaw.Agent/AgentRuntime.cs b/src/OpenClaw.Agent/AgentRuntime.cs index 16e6f47..df697c1 100644 --- a/src/OpenClaw.Agent/AgentRuntime.cs +++ b/src/OpenClaw.Agent/AgentRuntime.cs @@ -65,6 +65,7 @@ public sealed class AgentRuntime : IAgentRuntime private readonly SkillsConfig? _skillsConfig; private readonly string? _skillWorkspacePath; private readonly IReadOnlyList _pluginSkillDirs; + private readonly string? _memoryRecallPrefix; private readonly object _skillGate = new(); private string[] _loadedSkillNames = []; private int _skillPromptLength; @@ -158,6 +159,9 @@ public AgentRuntime( _isContractRuntimeBudgetExceeded = isContractRuntimeBudgetExceeded; _recordContractTurnUsage = recordContractTurnUsage; _appendContractSnapshot = appendContractSnapshot; + var projectId = gatewayConfig?.Memory.ProjectId + ?? Environment.GetEnvironmentVariable("OPENCLAW_PROJECT"); + _memoryRecallPrefix = string.IsNullOrWhiteSpace(projectId) ? null : $"project:{projectId.Trim()}:"; ApplySkills(skills ?? []); } @@ -326,11 +330,15 @@ public async Task RunAsync( // Extract token usage from response var inputTokens = response.Usage?.InputTokenCount ?? 0; var outputTokens = response.Usage?.OutputTokenCount ?? 0; + var cacheUsage = PromptCacheUsageExtractor.FromUsage(response.Usage); turnCtx.RecordLlmCall(llmSw.Elapsed, inputTokens, outputTokens); _metrics?.IncrementLlmCalls(); _metrics?.AddInputTokens(inputTokens); _metrics?.AddOutputTokens(outputTokens); + _metrics?.AddPromptCacheReads(cacheUsage.CacheReadTokens); + _metrics?.AddPromptCacheWrites(cacheUsage.CacheWriteTokens); _providerUsage?.AddTokens(executionResult.ProviderId, executionResult.ModelId, inputTokens, outputTokens); + _providerUsage?.AddCacheTokens(executionResult.ProviderId, executionResult.ModelId, cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); _providerUsage?.RecordTurn( session.Id, session.ChannelId, @@ -338,10 +346,13 @@ public async Task RunAsync( executionResult.ModelId, inputTokens, outputTokens, + cacheUsage.CacheReadTokens, + cacheUsage.CacheWriteTokens, LlmExecutionEstimateBuilder.BuildInputTokenEstimate(messages, inputTokens, _skillPromptLength)); // Track token usage on the session session.AddTokenUsage(inputTokens, outputTokens); + session.AddCacheUsage(cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); _recordContractTurnUsage?.Invoke(session, executionResult.ProviderId, executionResult.ModelId, inputTokens, outputTokens); if (TryRejectContractBudget(session, out contractBudgetMessage)) @@ -496,6 +507,7 @@ public async IAsyncEnumerable RunStreamingAsync( } session.AddTokenUsage(streamResult.InputTokens, streamResult.OutputTokens); + session.AddCacheUsage(streamResult.CacheReadTokens, streamResult.CacheWriteTokens); if (!string.IsNullOrWhiteSpace(streamResult.ProviderId) && !string.IsNullOrWhiteSpace(streamResult.ModelId)) _recordContractTurnUsage?.Invoke(session, streamResult.ProviderId, streamResult.ModelId, streamResult.InputTokens, streamResult.OutputTokens); if (!string.IsNullOrWhiteSpace(streamResult.ProviderId) && !string.IsNullOrWhiteSpace(streamResult.ModelId)) @@ -507,6 +519,8 @@ public async IAsyncEnumerable RunStreamingAsync( streamResult.ModelId, streamResult.InputTokens, streamResult.OutputTokens, + streamResult.CacheReadTokens, + streamResult.CacheWriteTokens, LlmExecutionEstimateBuilder.BuildInputTokenEstimate(messages, streamResult.InputTokens, _skillPromptLength)); } @@ -632,9 +646,16 @@ private async ValueTask TryInjectRecallAsync(List messages, string try { var limit = 
Math.Clamp(_recall.MaxNotes, 1, 32); - var hits = await search.SearchNotesAsync(userMessage, prefix: null, limit, ct); + _metrics?.IncrementMemoryRecallSearches(); + var hits = await search.SearchNotesAsync(userMessage, _memoryRecallPrefix, limit, ct); + if (hits.Count == 0 && !string.IsNullOrWhiteSpace(_memoryRecallPrefix)) + { + _metrics?.IncrementMemoryRecallSearches(); + hits = await search.SearchNotesAsync(userMessage, prefix: null, limit, ct); + } if (hits.Count == 0) return; + _metrics?.AddMemoryRecallHits(hits.Count); var maxChars = Math.Clamp(_recall.MaxChars, 256, 100_000); @@ -743,6 +764,8 @@ private sealed class StreamCollectResult public List ToolCalls { get; } = []; public int InputTokens { get; set; } public int OutputTokens { get; set; } + public int CacheReadTokens { get; set; } + public int CacheWriteTokens { get; set; } public string? ProviderId { get; set; } public string? ModelId { get; set; } public string? Error { get; set; } @@ -789,6 +812,11 @@ private async Task StreamLlmCollectAsync( result.InputTokens = (int)usage.Details.InputTokenCount.Value; if (usage.Details.OutputTokenCount is > 0) result.OutputTokens = (int)usage.Details.OutputTokenCount.Value; + var cacheUsage = PromptCacheUsageExtractor.FromUsage(usage.Details); + if (cacheUsage.CacheReadTokens > 0) + result.CacheReadTokens = (int)cacheUsage.CacheReadTokens; + if (cacheUsage.CacheWriteTokens > 0) + result.CacheWriteTokens = (int)cacheUsage.CacheWriteTokens; } } } @@ -884,6 +912,11 @@ private async Task StreamLlmCollectAsync( result.InputTokens = (int)usage.Details.InputTokenCount.Value; if (usage.Details.OutputTokenCount is > 0) result.OutputTokens = (int)usage.Details.OutputTokenCount.Value; + var cacheUsage = PromptCacheUsageExtractor.FromUsage(usage.Details); + if (cacheUsage.CacheReadTokens > 0) + result.CacheReadTokens = (int)cacheUsage.CacheReadTokens; + if (cacheUsage.CacheWriteTokens > 0) + result.CacheWriteTokens = (int)cacheUsage.CacheWriteTokens; } } } @@ -936,7 +969,10 @@ private async Task StreamLlmCollectAsync( _metrics?.IncrementLlmCalls(); _metrics?.AddInputTokens(result.InputTokens); _metrics?.AddOutputTokens(result.OutputTokens); + _metrics?.AddPromptCacheReads(result.CacheReadTokens); + _metrics?.AddPromptCacheWrites(result.CacheWriteTokens); _providerUsage?.AddTokens(_config.Provider, options.ModelId ?? _config.Model, result.InputTokens, result.OutputTokens); + _providerUsage?.AddCacheTokens(_config.Provider, options.ModelId ?? _config.Model, result.CacheReadTokens, result.CacheWriteTokens); result.ProviderId = _config.Provider; result.ModelId = options.ModelId ?? 
_config.Model;
@@ -1258,6 +1294,7 @@ public async Task CompactHistoryAsync(Session session, CancellationToken ct)
 
         if (!string.IsNullOrWhiteSpace(summary))
         {
+            _metrics?.IncrementMemoryCompactions();
             session.History.RemoveRange(0, toSummarizeCount);
             session.History.Insert(0, new ChatTurn
             {
diff --git a/src/OpenClaw.Core/Memory/FileMemoryStore.cs b/src/OpenClaw.Core/Memory/FileMemoryStore.cs
index c6e05b5..83ee806 100644
--- a/src/OpenClaw.Core/Memory/FileMemoryStore.cs
+++ b/src/OpenClaw.Core/Memory/FileMemoryStore.cs
@@ -6,6 +6,7 @@
 using Microsoft.Extensions.Logging;
 using OpenClaw.Core.Abstractions;
 using OpenClaw.Core.Models;
+using OpenClaw.Core.Observability;
 
 namespace OpenClaw.Core.Memory;
 
@@ -37,12 +38,17 @@ public sealed class FileMemoryStore : IMemoryStore, IMemoryNoteSearch, IMemoryRe
     private readonly string _branchesPath;
     private readonly IMemoryCache _sessionCache;
     private readonly SemaphoreSlim[] _sessionLoadStripes;
+    private readonly SemaphoreSlim _noteIndexGate = new(1, 1);
+    private readonly ConcurrentDictionary<string, NoteIndexEntry> _noteIndex = new(StringComparer.Ordinal);
     private readonly ILogger? _logger;
+    private readonly RuntimeMetrics? _metrics;
+    private int _noteIndexInitialized;
 
-    public FileMemoryStore(string basePath, int maxCachedSessions = 100, ILogger? logger = null)
+    public FileMemoryStore(string basePath, int maxCachedSessions = 100, ILogger? logger = null, RuntimeMetrics? metrics = null)
     {
         _basePath = basePath ?? throw new ArgumentNullException(nameof(basePath));
         _logger = logger;
+        _metrics = metrics;
 
         _sessionsPath = Path.Combine(_basePath, "sessions");
         _notesPath = Path.Combine(_basePath, "notes");
@@ -68,7 +74,11 @@ public FileMemoryStore(string basePath, int maxCachedSessions = 100, ILogger
 
     public ValueTask<IReadOnlyList<string>> ListNotesWithPrefixAsync(string prefix, CancellationToken ct)
     {
-        var results = new List<string>();
-
-        try
-        {
-            var files = Directory.EnumerateFiles(_notesPath, "*.md");
-            foreach (var file in files)
-            {
-                var encodedKey = Path.GetFileNameWithoutExtension(file);
-                var key = ResolveNoteKey(encodedKey);
-
-                if (key.StartsWith(prefix, StringComparison.Ordinal))
-                    results.Add(key);
-            }
-        }
-        catch
-        {
-            // Return empty list on error
-        }
-
-        return ValueTask.FromResult<IReadOnlyList<string>>(results);
+        return ListNotesWithPrefixCoreAsync(prefix ?? "", ct);
     }
 
     public async ValueTask<IReadOnlyList<MemoryNoteHit>> SearchNotesAsync(string query, string? prefix, int limit, CancellationToken ct)
@@ -340,56 +335,48 @@ public async ValueTask<IReadOnlyList<MemoryNoteHit>> SearchNotesAsync(string que
 
         limit = Math.Clamp(limit, 1, 50);
         prefix ??= "";
-
-        var hits = new List<MemoryNoteHit>(capacity: Math.Min(limit, 16));
 
         try
         {
-            foreach (var file in Directory.EnumerateFiles(_notesPath, "*.md"))
+            await EnsureNoteIndexLoadedAsync(ct);
+            var normalizedQuery = NormalizeSearchText(query);
+            if (normalizedQuery.Length == 0)
+                return [];
+
+            var terms = BuildQueryTerms(normalizedQuery);
+            var candidates = _noteIndex.Values
+                .Where(entry => string.IsNullOrEmpty(prefix) || entry.Key.StartsWith(prefix, StringComparison.Ordinal))
+                .Select(entry => new { Entry = entry, Score = ScoreNoteEntry(entry, normalizedQuery, terms) })
+                .Where(static item => item.Score > 0)
+                .OrderByDescending(static item => item.Score)
+                .ThenByDescending(static item => item.Entry.UpdatedAt)
+                .ThenBy(static item => item.Entry.Key, StringComparer.Ordinal)
+                .Take(Math.Min(limit * 4, 64))
+                .ToArray();
+
+            var hits = new List<MemoryNoteHit>(capacity: Math.Min(limit, candidates.Length));
+            foreach (var candidate in candidates)
             {
                 ct.ThrowIfCancellationRequested();
 
-                var encodedKey = Path.GetFileNameWithoutExtension(file);
-                var key = ResolveNoteKey(encodedKey);
-
-                if (!string.IsNullOrEmpty(prefix) && !key.StartsWith(prefix, StringComparison.Ordinal))
-                    continue;
-
-                string content;
-                try
-                {
-                    content = await File.ReadAllTextAsync(file, ct);
-                }
-                catch
-                {
-                    continue;
-                }
-
-                if (content.IndexOf(query, StringComparison.OrdinalIgnoreCase) < 0 &&
-                    key.IndexOf(query, StringComparison.OrdinalIgnoreCase) < 0)
-                {
-                    continue;
-                }
-
-                var updatedAt = File.GetLastWriteTimeUtc(file);
-
+                var content = await LoadNoteAsync(candidate.Entry.Key, ct) ?? candidate.Entry.PreviewContent;
                 hits.Add(new MemoryNoteHit
                 {
-                    Key = key,
+                    Key = candidate.Entry.Key,
                     Content = content,
-                    UpdatedAt = new DateTimeOffset(updatedAt, TimeSpan.Zero),
-                    Score = 1.0f
+                    UpdatedAt = candidate.Entry.UpdatedAt,
+                    Score = candidate.Score
                 });
 
                 if (hits.Count >= limit)
                     break;
             }
+
+            return hits;
         }
         catch
        {
             return [];
         }
-
-        return hits;
     }
 
     public async ValueTask SaveBranchAsync(SessionBranch branch, CancellationToken ct)
@@ -794,6 +781,129 @@ private ValueTask AddToCacheAsync(string sessionId, Session session)
         return ValueTask.CompletedTask;
     }
 
+    private async ValueTask<IReadOnlyList<string>> ListNotesWithPrefixCoreAsync(string prefix, CancellationToken ct)
+    {
+        try
+        {
+            await EnsureNoteIndexLoadedAsync(ct);
+            return _noteIndex.Keys
+                .Where(key => key.StartsWith(prefix, StringComparison.Ordinal))
+                .OrderBy(static key => key, StringComparer.Ordinal)
+                .ToArray();
+        }
+        catch
+        {
+            return [];
+        }
+    }
+
+    private async ValueTask EnsureNoteIndexLoadedAsync(CancellationToken ct)
+    {
+        if (Volatile.Read(ref _noteIndexInitialized) != 0)
+            return;
+
+        await _noteIndexGate.WaitAsync(ct);
+        try
+        {
+            if (_noteIndexInitialized != 0)
+                return;
+
+            _noteIndex.Clear();
+            foreach (var file in Directory.EnumerateFiles(_notesPath, "*.md"))
+            {
+                ct.ThrowIfCancellationRequested();
+
+                var encodedKey = Path.GetFileNameWithoutExtension(file);
+                var key = ResolveNoteKey(encodedKey);
+
+                string content;
+                try
+                {
+                    content = await File.ReadAllTextAsync(file, ct);
+                }
+                catch
+                {
+                    continue;
+                }
+
+                var updatedAt = new DateTimeOffset(File.GetLastWriteTimeUtc(file), TimeSpan.Zero);
+                _noteIndex[key] = CreateNoteIndexEntry(key, content, updatedAt);
+            }
+
+            Volatile.Write(ref _noteIndexInitialized, 1);
+        }
+        finally
+        {
+            _noteIndexGate.Release();
+        }
+    }
+
+    private void UpsertNoteIndexEntry(string key, string content, DateTimeOffset updatedAt)
+    {
+        if (Volatile.Read(ref _noteIndexInitialized) == 0)
+            return;
+
+        _noteIndex[key] = CreateNoteIndexEntry(key, content, updatedAt);
+    }
+
+    private static NoteIndexEntry CreateNoteIndexEntry(string key, string content, DateTimeOffset updatedAt)
+    {
+        content ??= "";
+        return new NoteIndexEntry
+        {
+            Key = key,
+            PreviewContent = content.Length <= 4_096 ? content : content[..4_096] + "…",
+            SearchText = NormalizeSearchText($"{key}\n{content}"),
+            UpdatedAt = updatedAt
+        };
+    }
+
+    private static string NormalizeSearchText(string value)
+    {
+        if (string.IsNullOrWhiteSpace(value))
+            return string.Empty;
+
+        var normalized = value.Replace("\r\n", "\n", StringComparison.Ordinal)
+            .Replace('\r', '\n')
+            .ToLowerInvariant();
+
+        return normalized.Length <= 16_384 ? normalized : normalized[..16_384];
+    }
+
+    private static string[] BuildQueryTerms(string normalizedQuery)
+    {
+        return normalizedQuery
+            .Split([' ', '\n', '\t', ',', '.', ';', ':', '!', '?', '(', ')', '[', ']', '{', '}', '"', '\'', '/', '\\', '-', '_'], StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
+            .Where(static term => term.Length >= 3)
+            .Distinct(StringComparer.Ordinal)
+            .Take(8)
+            .ToArray();
+    }
+
+    private static float ScoreNoteEntry(NoteIndexEntry entry, string normalizedQuery, IReadOnlyList<string> terms)
+    {
+        var score = 0f;
+        if (entry.SearchText.Contains(normalizedQuery, StringComparison.Ordinal))
+            score += 6f;
+        if (entry.Key.Contains(normalizedQuery, StringComparison.OrdinalIgnoreCase))
+            score += 4f;
+
+        foreach (var term in terms)
+        {
+            if (entry.Key.Contains(term, StringComparison.OrdinalIgnoreCase))
+                score += 2f;
+            if (entry.SearchText.Contains(term, StringComparison.Ordinal))
+                score += 1f;
+        }
+
+        if (score <= 0f)
+            return 0f;
+
+        var ageDays = Math.Max(0d, (DateTimeOffset.UtcNow - entry.UpdatedAt).TotalDays);
+        var recencyBoost = (float)Math.Max(0.1d, 1.5d - Math.Min(1.4d, ageDays / 14d));
+        return score + recencyBoost;
+    }
+
     private async ValueTask PersistOriginalNoteKeyAsync(string key, string keyPath, string keyTempPath, CancellationToken ct)
     {
         if (!RequiresKeySidecar(key))
@@ -876,6 +986,14 @@ private static string DecodeKey(string encoded)
         }
     }
 
+    private sealed class NoteIndexEntry
+    {
+        public required string Key { get; init; }
+        public required string PreviewContent { get; init; }
+        public required string SearchText { get; init; }
+        public required DateTimeOffset UpdatedAt { get; init; }
+    }
+
     // ── ISessionAdminStore ────────────────────────────────────────────────
 
     public async ValueTask ListSessionsAsync(
diff --git a/src/OpenClaw.Core/Memory/MemoryRetentionArchive.cs b/src/OpenClaw.Core/Memory/MemoryRetentionArchive.cs
index 309e8c1..2e9aaa0 100644
--- a/src/OpenClaw.Core/Memory/MemoryRetentionArchive.cs
+++ b/src/OpenClaw.Core/Memory/MemoryRetentionArchive.cs
@@ -106,25 +106,37 @@ public static (int DeletedFiles, int Errors, List<string> ErrorMessages) PurgeEx
         {
             ct.ThrowIfCancellationRequested();
 
+            var shouldDelete = false;
             try
             {
-                using var stream = File.OpenRead(file);
-                using var doc = JsonDocument.Parse(stream);
-                if (!doc.RootElement.TryGetProperty("sweptAtUtc", out var sweptAtElement) ||
-                    sweptAtElement.ValueKind != JsonValueKind.String ||
-                    !DateTime.TryParse(
-                        sweptAtElement.GetString(),
-                        provider: null,
-                        System.Globalization.DateTimeStyles.RoundtripKind,
-                        out var sweptAtUtc))
+                if (TryGetArchiveSweepDayUtc(archiveRoot, file, out var archiveDayUtc))
                 {
-                    var fallbackLastWriteUtc = File.GetLastWriteTimeUtc(file);
-                    if (fallbackLastWriteUtc >=
cutoff) + if (archiveDayUtc > cutoff.Date) continue; + if (archiveDayUtc < cutoff.Date) + shouldDelete = true; } - else if (sweptAtUtc >= cutoff) + + if (!shouldDelete) { - continue; + using var stream = File.OpenRead(file); + using var doc = JsonDocument.Parse(stream); + if (!doc.RootElement.TryGetProperty("sweptAtUtc", out var sweptAtElement) || + sweptAtElement.ValueKind != JsonValueKind.String || + !DateTime.TryParse( + sweptAtElement.GetString(), + provider: null, + System.Globalization.DateTimeStyles.RoundtripKind, + out var sweptAtUtc)) + { + var fallbackLastWriteUtc = File.GetLastWriteTimeUtc(file); + if (fallbackLastWriteUtc >= cutoff) + continue; + } + else if (sweptAtUtc >= cutoff) + { + continue; + } } } catch (Exception ex) @@ -159,6 +171,33 @@ private static string EncodeId(string id) return Convert.ToHexString(hash).ToLowerInvariant(); } + private static bool TryGetArchiveSweepDayUtc(string archiveRoot, string filePath, out DateTime archiveDayUtc) + { + archiveDayUtc = default; + + try + { + var relative = Path.GetRelativePath(archiveRoot, filePath); + var segments = relative.Split(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); + if (segments.Length < 4) + return false; + + if (!int.TryParse(segments[0], out var year) || + !int.TryParse(segments[1], out var month) || + !int.TryParse(segments[2], out var day)) + { + return false; + } + + archiveDayUtc = new DateTime(year, month, day, 0, 0, 0, DateTimeKind.Utc); + return true; + } + catch + { + return false; + } + } + private static void CleanupEmptyDirectories(string archiveRoot) { foreach (var dir in Directory.EnumerateDirectories(archiveRoot, "*", SearchOption.AllDirectories) diff --git a/src/OpenClaw.Core/Memory/SqliteMemoryStore.cs b/src/OpenClaw.Core/Memory/SqliteMemoryStore.cs index efc5189..2d14384 100644 --- a/src/OpenClaw.Core/Memory/SqliteMemoryStore.cs +++ b/src/OpenClaw.Core/Memory/SqliteMemoryStore.cs @@ -163,9 +163,14 @@ INSERT INTO notes_fts(key, content) { return JsonSerializer.Deserialize(json, CoreJsonContext.Default.Session); } - catch + catch (Exception ex) { - return null; + _logger?.LogError(ex, "Persisted sqlite session row for {SessionId} is corrupt or unreadable", sessionId); + throw new MemoryStoreCorruptionException( + $"Session '{sessionId}' could not be loaded because its persisted sqlite state is corrupt.", + sessionId, + $"{_dbPath}#sessions/{sessionId}", + ex); } } @@ -516,9 +521,10 @@ ON CONFLICT(branch_id) DO UPDATE SET { return JsonSerializer.Deserialize(json, CoreJsonContext.Default.SessionBranch); } - catch + catch (Exception ex) { - return null; + _logger?.LogError(ex, "Persisted sqlite branch row for {BranchId} is corrupt or unreadable", branchId); + throw new InvalidDataException($"Branch '{branchId}' could not be loaded because its persisted sqlite state is corrupt.", ex); } } @@ -955,11 +961,11 @@ public async Task BackfillEmbeddingsAsync(int batchSize = 50, CancellationToken if (!_enableVectors || _embeddingGenerator is null) return; + await using var conn = new SqliteConnection(ConnectionString); + await conn.OpenAsync(ct); + while (!ct.IsCancellationRequested) { - await using var conn = new SqliteConnection(ConnectionString); - await conn.OpenAsync(ct); - await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT key, content FROM notes WHERE embedding IS NULL LIMIT $limit;"; cmd.Parameters.AddWithValue("$limit", batchSize); @@ -972,6 +978,8 @@ public async Task BackfillEmbeddingsAsync(int batchSize = 50, CancellationToken if (batch.Count == 0) break; + 
var updates = new List<(string Key, byte[] Embedding)>(batch.Count); + foreach (var (key, content) in batch) { try @@ -979,14 +987,7 @@ public async Task BackfillEmbeddingsAsync(int batchSize = 50, CancellationToken var result = await _embeddingGenerator.GenerateAsync([content], cancellationToken: ct); if (result is { Count: > 0 }) { - var blob = SerializeEmbedding(result[0]); - await using var updateConn = new SqliteConnection(ConnectionString); - await updateConn.OpenAsync(ct); - await using var updateCmd = updateConn.CreateCommand(); - updateCmd.CommandText = "UPDATE notes SET embedding = $embedding WHERE key = $key;"; - updateCmd.Parameters.AddWithValue("$embedding", blob); - updateCmd.Parameters.AddWithValue("$key", key); - await updateCmd.ExecuteNonQueryAsync(ct); + updates.Add((key, SerializeEmbedding(result[0]))); } } catch (Exception ex) @@ -994,6 +995,25 @@ public async Task BackfillEmbeddingsAsync(int batchSize = 50, CancellationToken _logger?.LogWarning(ex, "Backfill embedding failed for note '{Key}'", key); } } + + if (updates.Count == 0) + continue; + + await using var tx = (SqliteTransaction)await conn.BeginTransactionAsync(ct); + await using var updateCmd = conn.CreateCommand(); + updateCmd.Transaction = tx; + updateCmd.CommandText = "UPDATE notes SET embedding = $embedding WHERE key = $key;"; + var embeddingParam = updateCmd.Parameters.Add("$embedding", SqliteType.Blob); + var keyParam = updateCmd.Parameters.Add("$key", SqliteType.Text); + + foreach (var update in updates) + { + embeddingParam.Value = update.Embedding; + keyParam.Value = update.Key; + await updateCmd.ExecuteNonQueryAsync(ct); + } + + await tx.CommitAsync(ct); } } diff --git a/src/OpenClaw.Core/Models/GatewayConfig.cs b/src/OpenClaw.Core/Models/GatewayConfig.cs index 490828f..2bced61 100644 --- a/src/OpenClaw.Core/Models/GatewayConfig.cs +++ b/src/OpenClaw.Core/Models/GatewayConfig.cs @@ -34,6 +34,7 @@ public sealed class GatewayConfig public TailscaleConfig Tailscale { get; set; } = new(); public GmailPubSubConfig GmailPubSub { get; set; } = new(); public MdnsConfig Mdns { get; set; } = new(); + public DiagnosticsConfig Diagnostics { get; set; } = new(); public string UsageFooter { get; set; } = "off"; // "off", "tokens", "full" public int MaxConcurrentSessions { get; set; } = 64; @@ -94,6 +95,33 @@ public sealed class LlmProviderConfig /// Seconds the circuit breaker stays open before probing. public int CircuitBreakerCooldownSeconds { get; set; } = 30; + + public PromptCachingConfig PromptCaching { get; set; } = new(); +} + +public sealed class PromptCachingConfig +{ + public bool? Enabled { get; set; } + public string? Retention { get; set; } // none | short | long | auto + public string? Dialect { get; set; } // auto | openai | anthropic | gemini | none + public bool? KeepWarmEnabled { get; set; } + public int KeepWarmIntervalMinutes { get; set; } = 55; + public bool? TraceEnabled { get; set; } + public string? TraceFilePath { get; set; } +} + +public sealed class DiagnosticsConfig +{ + public PromptCacheTraceConfig CacheTrace { get; set; } = new(); +} + +public sealed class PromptCacheTraceConfig +{ + public bool Enabled { get; set; } + public string? 
FilePath { get; set; } + public bool IncludeMessages { get; set; } = true; + public bool IncludePrompt { get; set; } = true; + public bool IncludeSystem { get; set; } = true; } public sealed class MemoryConfig @@ -113,7 +141,7 @@ public sealed class MemoryConfig public bool EnableCompaction { get; set; } = false; /// Number of history turns that triggers compaction (must exceed MaxHistoryTurns). - public int CompactionThreshold { get; set; } = 40; + public int CompactionThreshold { get; set; } = 80; /// Number of recent turns to keep verbatim during compaction. public int CompactionKeepRecent { get; set; } = 10; diff --git a/src/OpenClaw.Core/Models/ModelProfiles.cs b/src/OpenClaw.Core/Models/ModelProfiles.cs index 742ed69..16f399d 100644 --- a/src/OpenClaw.Core/Models/ModelProfiles.cs +++ b/src/OpenClaw.Core/Models/ModelProfiles.cs @@ -17,6 +17,7 @@ public sealed class ModelProfileConfig public string[] FallbackProfileIds { get; set; } = []; public string[] FallbackModels { get; set; } = []; public ModelCapabilities? Capabilities { get; set; } + public PromptCachingConfig? PromptCaching { get; set; } } public sealed class ModelCapabilities @@ -31,6 +32,10 @@ public sealed class ModelCapabilities public bool SupportsSystemMessages { get; set; } = true; public bool SupportsImageInput { get; set; } public bool SupportsAudioInput { get; set; } + public bool SupportsPromptCaching { get; set; } + public bool SupportsExplicitCacheRetention { get; set; } + public bool ReportsCacheReadTokens { get; set; } + public bool ReportsCacheWriteTokens { get; set; } public int MaxContextTokens { get; set; } public int MaxOutputTokens { get; set; } } @@ -62,6 +67,7 @@ public sealed class ModelProfile public string[] FallbackProfileIds { get; init; } = []; public string[] FallbackModels { get; init; } = []; public required ModelCapabilities Capabilities { get; init; } + public PromptCachingConfig PromptCaching { get; init; } = new(); public bool IsImplicit { get; init; } } @@ -75,6 +81,7 @@ public sealed class ModelProfileStatus public bool IsAvailable { get; init; } public string[] Tags { get; init; } = []; public required ModelCapabilities Capabilities { get; init; } + public PromptCachingConfig PromptCaching { get; init; } = new(); public string[] ValidationIssues { get; init; } = []; public string[] FallbackProfileIds { get; init; } = []; public string[] FallbackModels { get; init; } = []; diff --git a/src/OpenClaw.Core/Models/OperatorApiModels.cs b/src/OpenClaw.Core/Models/OperatorApiModels.cs index 592747c..8217c7d 100644 --- a/src/OpenClaw.Core/Models/OperatorApiModels.cs +++ b/src/OpenClaw.Core/Models/OperatorApiModels.cs @@ -70,6 +70,8 @@ public sealed class ProviderTurnUsageEntry public required string ModelId { get; init; } public long InputTokens { get; init; } public long OutputTokens { get; init; } + public long CacheReadTokens { get; init; } + public long CacheWriteTokens { get; init; } public required InputTokenComponentEstimate EstimatedInputTokensByComponent { get; init; } } diff --git a/src/OpenClaw.Core/Models/Session.cs b/src/OpenClaw.Core/Models/Session.cs index 1423798..b76a0f4 100644 --- a/src/OpenClaw.Core/Models/Session.cs +++ b/src/OpenClaw.Core/Models/Session.cs @@ -16,6 +16,8 @@ public sealed class Session { private long _totalInputTokens; private long _totalOutputTokens; + private long _totalCacheReadTokens; + private long _totalCacheWriteTokens; public required string Id { get; init; } public required string ChannelId { get; init; } @@ -69,6 +71,20 @@ public long 
TotalOutputTokens set => Interlocked.Exchange(ref _totalOutputTokens, value); } + /// Total input tokens served from upstream prompt cache across all turns. + public long TotalCacheReadTokens + { + get => Interlocked.Read(ref _totalCacheReadTokens); + set => Interlocked.Exchange(ref _totalCacheReadTokens, value); + } + + /// Total input tokens written into upstream prompt cache across all turns. + public long TotalCacheWriteTokens + { + get => Interlocked.Read(ref _totalCacheWriteTokens); + set => Interlocked.Exchange(ref _totalCacheWriteTokens, value); + } + /// Optional contract policy governing this session's execution limits. public ContractPolicy? ContractPolicy { get; set; } @@ -93,6 +109,14 @@ public void AddTokenUsage(long inputTokens, long outputTokens) Interlocked.Add(ref _totalOutputTokens, outputTokens); } + public void AddCacheUsage(long cacheReadTokens, long cacheWriteTokens) + { + if (cacheReadTokens != 0) + Interlocked.Add(ref _totalCacheReadTokens, cacheReadTokens); + if (cacheWriteTokens != 0) + Interlocked.Add(ref _totalCacheWriteTokens, cacheWriteTokens); + } + public long GetTotalTokens() => TotalInputTokens + TotalOutputTokens; } @@ -135,6 +159,9 @@ public sealed record ToolInvocation [JsonSerializable(typeof(RuntimeConfig))] [JsonSerializable(typeof(GatewayRuntimeState))] [JsonSerializable(typeof(LlmProviderConfig))] +[JsonSerializable(typeof(PromptCachingConfig))] +[JsonSerializable(typeof(DiagnosticsConfig))] +[JsonSerializable(typeof(PromptCacheTraceConfig))] [JsonSerializable(typeof(ModelsConfig))] [JsonSerializable(typeof(ModelProfileConfig))] [JsonSerializable(typeof(List))] diff --git a/src/OpenClaw.Core/Observability/PromptCacheUsage.cs b/src/OpenClaw.Core/Observability/PromptCacheUsage.cs new file mode 100644 index 0000000..9354d2c --- /dev/null +++ b/src/OpenClaw.Core/Observability/PromptCacheUsage.cs @@ -0,0 +1,54 @@ +using Microsoft.Extensions.AI; + +namespace OpenClaw.Core.Observability; + +public readonly record struct PromptCacheUsage(long CacheReadTokens, long CacheWriteTokens) +{ + public static PromptCacheUsage Empty { get; } = new(0, 0); +} + +public static class PromptCacheUsageExtractor +{ + private static readonly string[] CacheWriteKeys = + [ + "cache_write_tokens", + "cacheWriteTokens", + "cache_creation_input_tokens", + "cacheCreationInputTokens" + ]; + + public static PromptCacheUsage FromUsage(UsageDetails? usage) + { + if (usage is null) + return PromptCacheUsage.Empty; + + var cacheRead = usage.CachedInputTokenCount ?? 
0; + long cacheWrite = 0; + if (usage.AdditionalCounts is not null) + { + foreach (var key in CacheWriteKeys) + { + if (usage.AdditionalCounts.TryGetValue(key, out var value)) + { + cacheWrite = value; + break; + } + } + } + + return new PromptCacheUsage(cacheRead, cacheWrite); + } + + public static PromptCacheUsage Merge(params PromptCacheUsage[] items) + { + long cacheRead = 0; + long cacheWrite = 0; + foreach (var item in items) + { + cacheRead += item.CacheReadTokens; + cacheWrite += item.CacheWriteTokens; + } + + return new PromptCacheUsage(cacheRead, cacheWrite); + } +} diff --git a/src/OpenClaw.Core/Observability/ProviderUsageTracker.cs b/src/OpenClaw.Core/Observability/ProviderUsageTracker.cs index bd9e216..ede3341 100644 --- a/src/OpenClaw.Core/Observability/ProviderUsageTracker.cs +++ b/src/OpenClaw.Core/Observability/ProviderUsageTracker.cs @@ -28,6 +28,15 @@ public void AddTokens(string providerId, string modelId, long inputTokens, long counter.AddOutputTokens(outputTokens); } + public void AddCacheTokens(string providerId, string modelId, long cacheReadTokens, long cacheWriteTokens) + { + var counter = GetCounter(providerId, modelId); + if (cacheReadTokens > 0) + counter.AddCacheReadTokens(cacheReadTokens); + if (cacheWriteTokens > 0) + counter.AddCacheWriteTokens(cacheWriteTokens); + } + public IReadOnlyList Snapshot() => _usage .Select(static kvp => kvp.Value.Snapshot(kvp.Key.ProviderId, kvp.Key.ModelId)) @@ -42,6 +51,8 @@ public void RecordTurn( string modelId, long inputTokens, long outputTokens, + long cacheReadTokens, + long cacheWriteTokens, InputTokenComponentEstimate estimatedInputTokensByComponent) { _recentTurns.Enqueue(new ProviderTurnUsageEntry @@ -52,6 +63,8 @@ public void RecordTurn( ModelId = string.IsNullOrWhiteSpace(modelId) ? "default" : modelId, InputTokens = inputTokens, OutputTokens = outputTokens, + CacheReadTokens = cacheReadTokens, + CacheWriteTokens = cacheWriteTokens, EstimatedInputTokensByComponent = estimatedInputTokensByComponent }); @@ -60,6 +73,25 @@ public void RecordTurn( } } + public void RecordTurn( + string sessionId, + string channelId, + string providerId, + string modelId, + long inputTokens, + long outputTokens, + InputTokenComponentEstimate estimatedInputTokensByComponent) + => RecordTurn( + sessionId, + channelId, + providerId, + modelId, + inputTokens, + outputTokens, + cacheReadTokens: 0, + cacheWriteTokens: 0, + estimatedInputTokensByComponent); + public IReadOnlyList RecentTurns(string? sessionId = null, int limit = 50) { var normalizedLimit = Math.Clamp(limit, 1, MaxRecentTurns); @@ -73,6 +105,21 @@ public IReadOnlyList RecentTurns(string? sessionId = nul return items; } + public (long CacheReadTokens, long CacheWriteTokens) GetLatestSessionCacheTotals(string? sessionId) + { + var latest = _recentTurns.ToArray() + .Where(item => + !string.IsNullOrWhiteSpace(sessionId) && + string.Equals(item.SessionId, sessionId, StringComparison.Ordinal) && + (item.CacheReadTokens > 0 || item.CacheWriteTokens > 0)) + .OrderByDescending(static item => item.TimestampUtc) + .FirstOrDefault(); + + return latest is null + ? (0, 0) + : (latest.CacheReadTokens, latest.CacheWriteTokens); + } + private UsageCounter GetCounter(string providerId, string modelId) { var normalizedProviderId = string.IsNullOrWhiteSpace(providerId) ? 
"unknown" : providerId.Trim(); @@ -87,12 +134,16 @@ private sealed class UsageCounter private long _errors; private long _inputTokens; private long _outputTokens; + private long _cacheReadTokens; + private long _cacheWriteTokens; public void IncrementRequests() => Interlocked.Increment(ref _requests); public void IncrementRetries() => Interlocked.Increment(ref _retries); public void IncrementErrors() => Interlocked.Increment(ref _errors); public void AddInputTokens(long value) => Interlocked.Add(ref _inputTokens, value); public void AddOutputTokens(long value) => Interlocked.Add(ref _outputTokens, value); + public void AddCacheReadTokens(long value) => Interlocked.Add(ref _cacheReadTokens, value); + public void AddCacheWriteTokens(long value) => Interlocked.Add(ref _cacheWriteTokens, value); public ProviderUsageSnapshot Snapshot(string providerId, string modelId) => new() @@ -103,7 +154,9 @@ public ProviderUsageSnapshot Snapshot(string providerId, string modelId) Retries = Interlocked.Read(ref _retries), Errors = Interlocked.Read(ref _errors), InputTokens = Interlocked.Read(ref _inputTokens), - OutputTokens = Interlocked.Read(ref _outputTokens) + OutputTokens = Interlocked.Read(ref _outputTokens), + CacheReadTokens = Interlocked.Read(ref _cacheReadTokens), + CacheWriteTokens = Interlocked.Read(ref _cacheWriteTokens) }; } } @@ -117,4 +170,6 @@ public sealed class ProviderUsageSnapshot public long Errors { get; init; } public long InputTokens { get; init; } public long OutputTokens { get; init; } + public long CacheReadTokens { get; init; } + public long CacheWriteTokens { get; init; } } diff --git a/src/OpenClaw.Core/Observability/RuntimeMetrics.cs b/src/OpenClaw.Core/Observability/RuntimeMetrics.cs index bc0d8e6..a871645 100644 --- a/src/OpenClaw.Core/Observability/RuntimeMetrics.cs +++ b/src/OpenClaw.Core/Observability/RuntimeMetrics.cs @@ -38,6 +38,16 @@ public sealed class RuntimeMetrics private long _retentionSkippedProtectedSessions; private long _operatorAuditWriteFailures; private long _runtimeEventWriteFailures; + private long _sessionCacheHits; + private long _sessionCacheMisses; + private long _memoryRecallSearches; + private long _memoryRecallHits; + private long _memoryCompactions; + private long _promptCacheReads; + private long _promptCacheWrites; + private long _promptCacheWarmRuns; + private long _promptCacheWarmSkips; + private long _promptCacheWarmFailures; // ── Gauges ──────────────────────────────────────────────────────────── private int _activeSessions; @@ -74,6 +84,16 @@ public sealed class RuntimeMetrics public long RetentionSkippedProtectedSessions => Interlocked.Read(ref _retentionSkippedProtectedSessions); public long OperatorAuditWriteFailures => Interlocked.Read(ref _operatorAuditWriteFailures); public long RuntimeEventWriteFailures => Interlocked.Read(ref _runtimeEventWriteFailures); + public long SessionCacheHits => Interlocked.Read(ref _sessionCacheHits); + public long SessionCacheMisses => Interlocked.Read(ref _sessionCacheMisses); + public long MemoryRecallSearches => Interlocked.Read(ref _memoryRecallSearches); + public long MemoryRecallHits => Interlocked.Read(ref _memoryRecallHits); + public long MemoryCompactions => Interlocked.Read(ref _memoryCompactions); + public long PromptCacheReads => Interlocked.Read(ref _promptCacheReads); + public long PromptCacheWrites => Interlocked.Read(ref _promptCacheWrites); + public long PromptCacheWarmRuns => Interlocked.Read(ref _promptCacheWarmRuns); + public long PromptCacheWarmSkips => Interlocked.Read(ref 
_promptCacheWarmSkips); + public long PromptCacheWarmFailures => Interlocked.Read(ref _promptCacheWarmFailures); public int ActiveSessions => Volatile.Read(ref _activeSessions); public int CircuitBreakerState => Volatile.Read(ref _circuitBreakerState); public long RetentionLastRunAtUnixSeconds => Interlocked.Read(ref _retentionLastRunAtUnixSeconds); @@ -108,6 +128,16 @@ public sealed class RuntimeMetrics public void AddRetentionSkippedProtectedSessions(long n) => Interlocked.Add(ref _retentionSkippedProtectedSessions, n); public void IncrementOperatorAuditWriteFailures() => Interlocked.Increment(ref _operatorAuditWriteFailures); public void IncrementRuntimeEventWriteFailures() => Interlocked.Increment(ref _runtimeEventWriteFailures); + public void IncrementSessionCacheHits() => Interlocked.Increment(ref _sessionCacheHits); + public void IncrementSessionCacheMisses() => Interlocked.Increment(ref _sessionCacheMisses); + public void IncrementMemoryRecallSearches() => Interlocked.Increment(ref _memoryRecallSearches); + public void AddMemoryRecallHits(long n) => Interlocked.Add(ref _memoryRecallHits, n); + public void IncrementMemoryCompactions() => Interlocked.Increment(ref _memoryCompactions); + public void AddPromptCacheReads(long n) => Interlocked.Add(ref _promptCacheReads, n); + public void AddPromptCacheWrites(long n) => Interlocked.Add(ref _promptCacheWrites, n); + public void IncrementPromptCacheWarmRuns() => Interlocked.Increment(ref _promptCacheWarmRuns); + public void IncrementPromptCacheWarmSkips() => Interlocked.Increment(ref _promptCacheWarmSkips); + public void IncrementPromptCacheWarmFailures() => Interlocked.Increment(ref _promptCacheWarmFailures); public void SetActiveSessions(int count) => Volatile.Write(ref _activeSessions, count); public void SetCircuitBreakerState(int state) => Volatile.Write(ref _circuitBreakerState, state); public void SetRetentionLastRun(DateTimeOffset runAtUtc, long durationMs, bool succeeded) @@ -150,6 +180,16 @@ public void SetRetentionLastRun(DateTimeOffset runAtUtc, long durationMs, bool s RetentionSkippedProtectedSessions = RetentionSkippedProtectedSessions, OperatorAuditWriteFailures = OperatorAuditWriteFailures, RuntimeEventWriteFailures = RuntimeEventWriteFailures, + SessionCacheHits = SessionCacheHits, + SessionCacheMisses = SessionCacheMisses, + MemoryRecallSearches = MemoryRecallSearches, + MemoryRecallHits = MemoryRecallHits, + MemoryCompactions = MemoryCompactions, + PromptCacheReads = PromptCacheReads, + PromptCacheWrites = PromptCacheWrites, + PromptCacheWarmRuns = PromptCacheWarmRuns, + PromptCacheWarmSkips = PromptCacheWarmSkips, + PromptCacheWarmFailures = PromptCacheWarmFailures, RetentionLastRunAtUnixSeconds = RetentionLastRunAtUnixSeconds, RetentionLastRunDurationMs = RetentionLastRunDurationMs, RetentionLastRunSucceeded = RetentionLastRunSucceeded, @@ -188,6 +228,16 @@ public struct MetricsSnapshot public long RetentionSkippedProtectedSessions { get; set; } public long OperatorAuditWriteFailures { get; set; } public long RuntimeEventWriteFailures { get; set; } + public long SessionCacheHits { get; set; } + public long SessionCacheMisses { get; set; } + public long MemoryRecallSearches { get; set; } + public long MemoryRecallHits { get; set; } + public long MemoryCompactions { get; set; } + public long PromptCacheReads { get; set; } + public long PromptCacheWrites { get; set; } + public long PromptCacheWarmRuns { get; set; } + public long PromptCacheWarmSkips { get; set; } + public long PromptCacheWarmFailures { get; set; } public 
long RetentionLastRunAtUnixSeconds { get; set; } public long RetentionLastRunDurationMs { get; set; } public int RetentionLastRunSucceeded { get; set; } diff --git a/src/OpenClaw.Core/Pipeline/ChatCommandProcessor.cs b/src/OpenClaw.Core/Pipeline/ChatCommandProcessor.cs index 7fd21bb..cad13e3 100644 --- a/src/OpenClaw.Core/Pipeline/ChatCommandProcessor.cs +++ b/src/OpenClaw.Core/Pipeline/ChatCommandProcessor.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using OpenClaw.Core.Abstractions; using OpenClaw.Core.Models; +using OpenClaw.Core.Observability; using OpenClaw.Core.Sessions; namespace OpenClaw.Core.Pipeline; @@ -31,12 +32,14 @@ public sealed class ChatCommandProcessor }.ToFrozenSet(StringComparer.OrdinalIgnoreCase); private readonly SessionManager _sessionManager; + private readonly ProviderUsageTracker? _providerUsage; private readonly ConcurrentDictionary>> _dynamicCommands = new(StringComparer.OrdinalIgnoreCase); private Func>? _compactCallback; - public ChatCommandProcessor(SessionManager sessionManager) + public ChatCommandProcessor(SessionManager sessionManager, ProviderUsageTracker? providerUsage = null) { _sessionManager = sessionManager; + _providerUsage = providerUsage; } /// @@ -77,7 +80,8 @@ public DynamicCommandRegistrationResult RegisterDynamic(string command, Func 0 || session.TotalCacheWriteTokens > 0) + return (session.TotalCacheReadTokens, session.TotalCacheWriteTokens); + + return _providerUsage?.GetLatestSessionCacheTotals(session.Id) ?? (0, 0); + } } diff --git a/src/OpenClaw.Core/Validation/ConfigValidator.cs b/src/OpenClaw.Core/Validation/ConfigValidator.cs index e9a2f1e..9979165 100644 --- a/src/OpenClaw.Core/Validation/ConfigValidator.cs +++ b/src/OpenClaw.Core/Validation/ConfigValidator.cs @@ -20,6 +20,8 @@ public static class ConfigValidator "ollama", "azure-openai", "openai-compatible", + "anthropic-vertex", + "amazon-bedrock", "groq", "together", "lmstudio" @@ -52,9 +54,15 @@ public static IReadOnlyList Validate(Models.GatewayConfig config) errors.Add($"Llm.CircuitBreakerThreshold must be >= 1 (got {config.Llm.CircuitBreakerThreshold})."); if (config.Llm.CircuitBreakerCooldownSeconds < 1) errors.Add($"Llm.CircuitBreakerCooldownSeconds must be >= 1 (got {config.Llm.CircuitBreakerCooldownSeconds})."); + ValidatePromptCaching("Llm.PromptCaching", config.Llm.Provider, config.Llm.PromptCaching, errors, isDynamicProvider: false); ValidateModelProfiles(config, errors, pluginBackedProvidersPossible); // Memory + if (!string.Equals(config.Memory.Provider, "file", StringComparison.OrdinalIgnoreCase) && + !string.Equals(config.Memory.Provider, "sqlite", StringComparison.OrdinalIgnoreCase)) + { + errors.Add($"Memory.Provider '{config.Memory.Provider}' must be 'file' or 'sqlite'."); + } if (string.IsNullOrWhiteSpace(config.Memory.StoragePath)) errors.Add("Memory.StoragePath must be set."); if (config.Memory.MaxHistoryTurns < 1) @@ -538,6 +546,12 @@ private static void ValidateModelProfiles(GatewayConfig config, List err errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxContextTokens must be >= 0."); if (profile.Capabilities?.MaxOutputTokens < 0) errors.Add($"Models.Profiles.{profile.Id}.Capabilities.MaxOutputTokens must be >= 0."); + ValidatePromptCaching( + $"Models.Profiles.{profile.Id}.PromptCaching", + profile.Provider, + profile.PromptCaching, + errors, + isDynamicProvider: pluginBackedProvidersPossible && !BuiltInLlmProviders.Contains(profile.Provider)); } if (!hasExplicitProfiles) @@ -573,4 +587,68 @@ private static void 
ValidateModelProfiles(GatewayConfig config, List err private static string ResolveConfiguredPath(string? path) => ConfigPathResolver.Resolve(path); + + private static void ValidatePromptCaching( + string prefix, + string? providerId, + PromptCachingConfig? caching, + List errors, + bool isDynamicProvider) + { + if (caching is null || caching.Enabled != true) + return; + + var retention = (caching.Retention ?? "auto").Trim().ToLowerInvariant(); + if (retention is not ("none" or "short" or "long" or "auto")) + errors.Add($"{prefix}.Retention must be one of: none, short, long, auto."); + + var dialect = (caching.Dialect ?? "auto").Trim().ToLowerInvariant(); + if (dialect is not ("auto" or "openai" or "anthropic" or "gemini" or "none")) + errors.Add($"{prefix}.Dialect must be one of: auto, openai, anthropic, gemini, none."); + + var provider = (providerId ?? string.Empty).Trim(); + var requireExplicitDialect = + provider.Equals("openai-compatible", StringComparison.OrdinalIgnoreCase) + || provider.Equals("groq", StringComparison.OrdinalIgnoreCase) + || provider.Equals("together", StringComparison.OrdinalIgnoreCase) + || provider.Equals("lmstudio", StringComparison.OrdinalIgnoreCase) + || isDynamicProvider; + + if (requireExplicitDialect && dialect == "auto") + errors.Add($"{prefix}.Dialect must be explicit for provider '{provider}'."); + + if (caching.KeepWarmEnabled == true) + { + if (caching.KeepWarmIntervalMinutes < 5) + errors.Add($"{prefix}.KeepWarmIntervalMinutes must be >= 5 when keep-warm is enabled."); + + if (!SupportsExplicitCacheTtl(provider, dialect)) + { + errors.Add($"{prefix}.KeepWarmEnabled is only valid for providers with explicit cache TTL semantics."); + } + } + } + + private static bool SupportsExplicitCacheTtl(string? providerId, string? dialect) + { + var provider = (providerId ?? string.Empty).Trim(); + var normalizedDialect = (dialect ?? "auto").Trim(); + if (provider.Equals("anthropic", StringComparison.OrdinalIgnoreCase) || + provider.Equals("claude", StringComparison.OrdinalIgnoreCase) || + provider.Equals("anthropic-vertex", StringComparison.OrdinalIgnoreCase)) + return true; + + if (provider.Equals("amazon-bedrock", StringComparison.OrdinalIgnoreCase)) + return string.Equals(normalizedDialect, "anthropic", StringComparison.OrdinalIgnoreCase) + || string.Equals(normalizedDialect, "auto", StringComparison.OrdinalIgnoreCase); + + if (provider.Equals("gemini", StringComparison.OrdinalIgnoreCase) || + provider.Equals("google", StringComparison.OrdinalIgnoreCase)) + { + return string.Equals(normalizedDialect, "gemini", StringComparison.OrdinalIgnoreCase) + || string.Equals(normalizedDialect, "auto", StringComparison.OrdinalIgnoreCase); + } + + return false; + } } diff --git a/src/OpenClaw.Core/Validation/DoctorCheck.cs b/src/OpenClaw.Core/Validation/DoctorCheck.cs index 9096b4c..71c53ac 100644 --- a/src/OpenClaw.Core/Validation/DoctorCheck.cs +++ b/src/OpenClaw.Core/Validation/DoctorCheck.cs @@ -24,6 +24,11 @@ public static async Task RunAsync(GatewayConfig config, GatewayRuntimeStat allPassed &= Check("LLM API Key configured", () => !string.IsNullOrWhiteSpace(config.Llm.ApiKey)); allPassed &= Check("LLM max tokens > 0", () => config.Llm.MaxTokens > 0); + allPassed &= Check( + "Prompt cache config is provider-compatible", + () => HasValidPromptCacheConfiguration(config), + warnOnly: true, + detail: "OpenAI-compatible and dynamic providers require an explicit cache dialect. 
Keep-warm is only supported for Anthropic-family and Gemini profiles."); allPassed &= Check( "Model profile configuration is internally consistent", () => HasValidModelProfileConfiguration(config), @@ -261,6 +266,48 @@ private static void PrintResult(string description, bool passed, bool warnOnly, } } + private static bool HasValidPromptCacheConfiguration(GatewayConfig config) + { + static bool RequiresExplicitDialect(string provider) + => provider.Equals("openai-compatible", StringComparison.OrdinalIgnoreCase) + || provider.Equals("groq", StringComparison.OrdinalIgnoreCase) + || provider.Equals("together", StringComparison.OrdinalIgnoreCase) + || provider.Equals("lmstudio", StringComparison.OrdinalIgnoreCase); + + static bool SupportsKeepWarm(string provider, string dialect) + => (dialect.Equals("anthropic", StringComparison.OrdinalIgnoreCase) && + (provider.Equals("anthropic", StringComparison.OrdinalIgnoreCase) + || provider.Equals("claude", StringComparison.OrdinalIgnoreCase) + || provider.Equals("anthropic-vertex", StringComparison.OrdinalIgnoreCase) + || provider.Equals("amazon-bedrock", StringComparison.OrdinalIgnoreCase))) + || (dialect.Equals("gemini", StringComparison.OrdinalIgnoreCase) && + (provider.Equals("gemini", StringComparison.OrdinalIgnoreCase) + || provider.Equals("google", StringComparison.OrdinalIgnoreCase))); + + static bool IsValid(GatewayConfig root, string provider, PromptCachingConfig? caching) + { + if (caching is null || caching.Enabled != true) + return true; + + var dialect = (caching.Dialect ?? "auto").Trim(); + if (RequiresExplicitDialect(provider) && dialect.Equals("auto", StringComparison.OrdinalIgnoreCase)) + return false; + + return caching.KeepWarmEnabled != true || SupportsKeepWarm(provider, dialect); + } + + if (!IsValid(config, config.Llm.Provider, config.Llm.PromptCaching)) + return false; + + foreach (var profile in config.Models.Profiles) + { + if (!IsValid(config, profile.Provider, profile.PromptCaching)) + return false; + } + + return true; + } + private static async Task PingOpenSandboxAsync(GatewayConfig config) { if (string.IsNullOrWhiteSpace(config.Sandbox.Endpoint) || diff --git a/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs b/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs index 2e3bd51..86c3369 100644 --- a/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs +++ b/src/OpenClaw.Gateway/Composition/CoreServicesExtensions.cs @@ -13,6 +13,7 @@ using OpenClaw.Gateway.Bootstrap; using OpenClaw.Gateway.Extensions; using OpenClaw.Gateway.Models; +using OpenClaw.Gateway.PromptCaching; namespace OpenClaw.Gateway.Composition; @@ -30,7 +31,8 @@ public static IServiceCollection AddOpenClawCoreServices(this IServiceCollection services.AddSingleton(sp => new AllowlistManager(config.Memory.StoragePath, sp.GetRequiredService>())); - services.AddSingleton(_ => CreateMemoryStore(config)); + services.AddSingleton(); + services.AddSingleton(sp => CreateMemoryStore(config, sp.GetRequiredService())); services.AddSingleton(sp => { var memory = sp.GetRequiredService(); @@ -39,7 +41,6 @@ public static IServiceCollection AddOpenClawCoreServices(this IServiceCollection }); services.AddSingleton(sp => (ISessionSearchStore)sp.GetRequiredService()); AddFeatureStores(services, config); - services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -47,6 +48,9 @@ public static IServiceCollection AddOpenClawCoreServices(this IServiceCollection services.AddSingleton(sp => sp.GetRequiredService()); 
services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(sp => new ProviderPolicyService( config.Memory.StoragePath, @@ -107,13 +111,21 @@ public static IServiceCollection AddOpenClawCoreServices(this IServiceCollection config, sp.GetRequiredService().CreateLogger("SessionManager"), sp.GetRequiredService())); - services.AddSingleton(); + services.AddSingleton(sp => new MemoryRetentionSweeperService( + config, + sp.GetRequiredService(), + sp.GetRequiredService(), + sp.GetRequiredService(), + sp.GetRequiredService>(), + sp.GetRequiredService().GetAll)); services.AddSingleton(sp => sp.GetRequiredService()); services.AddHostedService(sp => sp.GetRequiredService()); services.AddSingleton(); services.AddSingleton(new WebSocketChannel(config.WebSocket)); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddHostedService(sp => sp.GetRequiredService()); services.AddSingleton(); return services; @@ -150,7 +162,7 @@ private static string ResolveSqliteDbPath(GatewayConfig config) return Path.GetFullPath(dbPath); } - private static IMemoryStore CreateMemoryStore(OpenClaw.Core.Models.GatewayConfig config) + private static IMemoryStore CreateMemoryStore(OpenClaw.Core.Models.GatewayConfig config, RuntimeMetrics metrics) { if (string.Equals(config.Memory.Provider, "sqlite", StringComparison.OrdinalIgnoreCase)) { @@ -181,6 +193,7 @@ private static IMemoryStore CreateMemoryStore(OpenClaw.Core.Models.GatewayConfig return new FileMemoryStore( config.Memory.StoragePath, - config.Memory.MaxCachedSessions ?? config.MaxConcurrentSessions); + config.Memory.MaxCachedSessions ?? config.MaxConcurrentSessions, + metrics: metrics); } } diff --git a/src/OpenClaw.Gateway/Endpoints/DiagnosticsEndpoints.cs b/src/OpenClaw.Gateway/Endpoints/DiagnosticsEndpoints.cs index 0d6bc9f..4f41f18 100644 --- a/src/OpenClaw.Gateway/Endpoints/DiagnosticsEndpoints.cs +++ b/src/OpenClaw.Gateway/Endpoints/DiagnosticsEndpoints.cs @@ -315,7 +315,7 @@ public static void MapOpenClawDiagnosticsEndpoints( sb.AppendLine("Provider Usage"); foreach (var item in runtime.ProviderUsage.Snapshot()) { - sb.AppendLine($"- {item.ProviderId}/{item.ModelId}: requests={item.Requests} retries={item.Retries} errors={item.Errors} tokens={item.InputTokens}in/{item.OutputTokens}out"); + sb.AppendLine($"- {item.ProviderId}/{item.ModelId}: requests={item.Requests} retries={item.Retries} errors={item.Errors} tokens={item.InputTokens}in/{item.OutputTokens}out cache={item.CacheReadTokens}read/{item.CacheWriteTokens}write"); } sb.AppendLine("- routes:"); foreach (var route in runtime.Operations.LlmExecution.SnapshotRoutes()) @@ -379,6 +379,14 @@ public static void MapOpenClawDiagnosticsEndpoints( } sb.AppendLine(); + sb.AppendLine("Prompt Cache"); + sb.AppendLine($"- enabled: {EndpointHelpers.ToBoolWord(startup.Config.Llm.PromptCaching.Enabled == true)}"); + sb.AppendLine($"- retention: {startup.Config.Llm.PromptCaching.Retention ?? "auto"}"); + sb.AppendLine($"- dialect: {startup.Config.Llm.PromptCaching.Dialect ?? 
"auto"}"); + sb.AppendLine($"- keep_warm: {EndpointHelpers.ToBoolWord(startup.Config.Llm.PromptCaching.KeepWarmEnabled == true)} interval_minutes={startup.Config.Llm.PromptCaching.KeepWarmIntervalMinutes}"); + sb.AppendLine($"- trace_enabled: {EndpointHelpers.ToBoolWord(startup.Config.Diagnostics.CacheTrace.Enabled || startup.Config.Llm.PromptCaching.TraceEnabled == true)}"); + sb.AppendLine(); + sb.AppendLine("Cron"); sb.AppendLine($"- enabled: {EndpointHelpers.ToBoolWord(startup.Config.Cron.Enabled)} jobs={startup.Config.Cron.Jobs.Count}"); foreach (var job in startup.Config.Cron.Jobs.Take(20)) diff --git a/src/OpenClaw.Gateway/Extensions/LlmClientFactory.cs b/src/OpenClaw.Gateway/Extensions/LlmClientFactory.cs index a54b34b..d6b10e1 100644 --- a/src/OpenClaw.Gateway/Extensions/LlmClientFactory.cs +++ b/src/OpenClaw.Gateway/Extensions/LlmClientFactory.cs @@ -86,6 +86,16 @@ public static IChatClient CreateChatClient(LlmProviderConfig config) .AsIChatClient(), "anthropic" or "claude" => CreateAnthropicClient(config) .AsIChatClient(config.Model), + "anthropic-vertex" => CreateAnthropicClient(new LlmProviderConfig + { + ApiKey = config.ApiKey, + Endpoint = config.Endpoint + ?? throw new InvalidOperationException( + "Endpoint must be set for provider 'anthropic-vertex'. " + + "Set OpenClaw:Llm:Endpoint or Models:Profiles::BaseUrl."), + Model = config.Model + }) + .AsIChatClient(config.Model), "gemini" or "google" => CreateGeminiClient(config), "ollama" => CreateOpenAiClient(new LlmProviderConfig { @@ -110,9 +120,19 @@ public static IChatClient CreateChatClient(LlmProviderConfig config) }) .GetChatClient(config.Model) .AsIChatClient(), + "amazon-bedrock" => CreateAnthropicClient(new LlmProviderConfig + { + ApiKey = config.ApiKey, + Endpoint = config.Endpoint + ?? throw new InvalidOperationException( + "Endpoint must be set for provider 'amazon-bedrock'. " + + "Use a Bedrock-compatible proxy endpoint or register a dynamic provider."), + Model = config.Model + }) + .AsIChatClient(config.Model), _ => throw new InvalidOperationException( $"Unsupported LLM provider: {config.Provider}. 
" + - "Supported: openai, anthropic, claude, gemini, google, ollama, azure-openai, openai-compatible, groq, together, lmstudio") + "Supported: openai, anthropic, claude, anthropic-vertex, gemini, google, ollama, azure-openai, openai-compatible, groq, together, lmstudio, amazon-bedrock") }; } @@ -143,6 +163,7 @@ public static IChatClient CreateChatClient(LlmProviderConfig config) Model = config.Model, Endpoint = config.Endpoint }, embeddingModel!), + "anthropic-vertex" or "amazon-bedrock" => null, _ => null }; } diff --git a/src/OpenClaw.Gateway/Extensions/MemoryRetentionSweeperService.cs b/src/OpenClaw.Gateway/Extensions/MemoryRetentionSweeperService.cs index 95e429d..31a2691 100644 --- a/src/OpenClaw.Gateway/Extensions/MemoryRetentionSweeperService.cs +++ b/src/OpenClaw.Gateway/Extensions/MemoryRetentionSweeperService.cs @@ -5,6 +5,7 @@ using OpenClaw.Core.Models; using OpenClaw.Core.Observability; using OpenClaw.Core.Sessions; +using OpenClaw.Gateway; namespace OpenClaw.Gateway.Extensions; @@ -20,11 +21,20 @@ public interface IMemoryRetentionCoordinator /// public sealed class MemoryRetentionSweeperService : BackgroundService, IMemoryRetentionCoordinator { + private static readonly HashSet ProtectedRetentionTags = new(StringComparer.OrdinalIgnoreCase) + { + "keep", + "pinned", + "retain", + "retention-exempt" + }; + private readonly GatewayConfig _config; private readonly SessionManager _sessionManager; private readonly RuntimeMetrics _metrics; private readonly ILogger _logger; private readonly IMemoryRetentionStore? _retentionStore; + private readonly Func>? _metadataSnapshotProvider; private readonly SemaphoreSlim _runGate = new(1, 1); private readonly object _statusLock = new(); private RetentionRunStatus _status; @@ -34,13 +44,15 @@ public MemoryRetentionSweeperService( SessionManager sessionManager, IMemoryStore memoryStore, RuntimeMetrics metrics, - ILogger logger) + ILogger logger, + Func>? 
metadataSnapshotProvider = null) { _config = config; _sessionManager = sessionManager; _metrics = metrics; _logger = logger; _retentionStore = memoryStore as IMemoryRetentionStore; + _metadataSnapshotProvider = metadataSnapshotProvider; _status = new RetentionRunStatus { @@ -216,6 +228,15 @@ private async ValueTask> BuildProtectedSetAsync(Cancellatio set.Add(session.Id); } + if (_metadataSnapshotProvider is not null) + { + foreach (var metadata in _metadataSnapshotProvider().Values) + { + if (metadata.Starred || metadata.Tags.Any(static tag => ProtectedRetentionTags.Contains(tag))) + set.Add(metadata.SessionId); + } + } + return set; } diff --git a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs index 27ce5f9..60768ae 100644 --- a/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs +++ b/src/OpenClaw.Gateway/GatewayLlmExecutionService.cs @@ -8,6 +8,7 @@ using OpenClaw.Core.Models; using OpenClaw.Core.Observability; using OpenClaw.Gateway.Models; +using OpenClaw.Gateway.PromptCaching; namespace OpenClaw.Gateway; @@ -36,6 +37,8 @@ private sealed class RouteState private readonly RuntimeEventStore _eventStore; private readonly RuntimeMetrics _runtimeMetrics; private readonly ProviderUsageTracker _providerUsage; + private readonly PromptCacheCoordinator _promptCacheCoordinator; + private readonly PromptCacheWarmRegistry _promptCacheWarmRegistry; private readonly ILogger _logger; private readonly ConcurrentDictionary _routes = new(StringComparer.OrdinalIgnoreCase); @@ -48,6 +51,31 @@ public GatewayLlmExecutionService( RuntimeMetrics runtimeMetrics, ProviderUsageTracker providerUsage, ILogger logger) + : this( + config, + modelProfiles, + selectionPolicy, + policyService, + eventStore, + runtimeMetrics, + providerUsage, + new PromptCacheCoordinator(config, new PromptCacheTraceWriter(config)), + new PromptCacheWarmRegistry(), + logger) + { + } + + public GatewayLlmExecutionService( + GatewayConfig config, + ConfiguredModelProfileRegistry modelProfiles, + IModelSelectionPolicy selectionPolicy, + ProviderPolicyService policyService, + RuntimeEventStore eventStore, + RuntimeMetrics runtimeMetrics, + ProviderUsageTracker providerUsage, + PromptCacheCoordinator promptCacheCoordinator, + PromptCacheWarmRegistry promptCacheWarmRegistry, + ILogger logger) { _config = config; _modelProfiles = modelProfiles; @@ -56,6 +84,8 @@ public GatewayLlmExecutionService( _eventStore = eventStore; _runtimeMetrics = runtimeMetrics; _providerUsage = providerUsage; + _promptCacheCoordinator = promptCacheCoordinator; + _promptCacheWarmRegistry = promptCacheWarmRegistry; _logger = logger; } @@ -67,6 +97,29 @@ public GatewayLlmExecutionService( RuntimeMetrics runtimeMetrics, ProviderUsageTracker providerUsage, ILogger logger) + : this( + config, + registry, + policyService, + eventStore, + runtimeMetrics, + providerUsage, + new PromptCacheCoordinator(config, new PromptCacheTraceWriter(config)), + new PromptCacheWarmRegistry(), + logger) + { + } + + public GatewayLlmExecutionService( + GatewayConfig config, + LlmProviderRegistry registry, + ProviderPolicyService policyService, + RuntimeEventStore eventStore, + RuntimeMetrics runtimeMetrics, + ProviderUsageTracker providerUsage, + PromptCacheCoordinator promptCacheCoordinator, + PromptCacheWarmRegistry promptCacheWarmRegistry, + ILogger logger) : this( config, CreateCompatibilityServices(config, registry), @@ -74,6 +127,8 @@ public GatewayLlmExecutionService( eventStore, runtimeMetrics, providerUsage, + 
promptCacheCoordinator, + promptCacheWarmRegistry, logger) { } @@ -85,6 +140,8 @@ private GatewayLlmExecutionService( RuntimeEventStore eventStore, RuntimeMetrics runtimeMetrics, ProviderUsageTracker providerUsage, + PromptCacheCoordinator promptCacheCoordinator, + PromptCacheWarmRegistry promptCacheWarmRegistry, ILogger logger) : this( config, @@ -94,6 +151,8 @@ private GatewayLlmExecutionService( eventStore, runtimeMetrics, providerUsage, + promptCacheCoordinator, + promptCacheWarmRegistry, logger) { } @@ -185,6 +244,10 @@ public async Task GetResponseAsync( continue; } + effectiveOptions.ModelId = modelId; + var prepared = _promptCacheCoordinator.Prepare(session, candidate.Profile, modelId, messages, effectiveOptions); + _promptCacheWarmRegistry.Record(prepared); + var routeState = GetOrAddRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, modelId); for (var attempt = 0; attempt <= registration.ProviderConfig.RetryCount; attempt++) @@ -230,11 +293,14 @@ public async Task GetResponseAsync( { using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(innerCt); timeoutCts.CancelAfter(TimeSpan.FromSeconds(registration.ProviderConfig.TimeoutSeconds)); - return await chatClient.GetResponseAsync(messages, effectiveOptions, timeoutCts.Token); + return await chatClient.GetResponseAsync(prepared.Messages, prepared.Options, timeoutCts.Token); } - return await chatClient.GetResponseAsync(messages, effectiveOptions, innerCt); + return await chatClient.GetResponseAsync(prepared.Messages, prepared.Options, innerCt); }, ct); + NormalizePromptCacheUsage(response); + var cacheUsage = PromptCacheUsageExtractor.FromUsage(response.Usage); + _promptCacheCoordinator.RecordResponse(prepared.Descriptor, cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); RecordEvent(session, turnContext, "llm", "request_completed", "info", $"LLM request completed for {candidate.Profile.ProviderId}/{modelId}", new() { @@ -306,6 +372,8 @@ public Task StartStreamingAsync( } var selectedModelId = ResolveRequestedModelId(session, candidate.Profile); + var prepared = _promptCacheCoordinator.Prepare(session, candidate.Profile, selectedModelId, messages, effectiveOptions); + _promptCacheWarmRegistry.Record(prepared); var routeState = GetOrAddRouteState(candidate.Profile.Id, candidate.Profile.ProviderId, selectedModelId); var chatClient = registration.Client; @@ -326,7 +394,6 @@ public Task StartStreamingAsync( ["policyRuleId"] = legacyPolicy.RuleId ?? 
"" }); - effectiveOptions.ModelId = selectedModelId; IAsyncEnumerable updates = StreamWithCircuitAsync( session, turnContext, @@ -334,10 +401,11 @@ public Task StartStreamingAsync( routeState, candidate.Profile.ProviderId, selectedModelId, - messages, - effectiveOptions, + prepared.Messages, + prepared.Options, registration.ProviderConfig.TimeoutSeconds, candidate.Profile.Id, + prepared.Descriptor, ct); return Task.FromResult(new LlmStreamingExecutionResult @@ -365,6 +433,7 @@ private async IAsyncEnumerable StreamWithCircuitAsync( ChatOptions options, int timeoutSeconds, string profileId, + PromptCacheDescriptor descriptor, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct) { routeState.CircuitBreaker.ThrowIfOpen(); @@ -405,6 +474,7 @@ private async IAsyncEnumerable StreamWithCircuitAsync( routeState.LastErrorAtUtc = DateTimeOffset.UtcNow; _runtimeMetrics.IncrementLlmErrors(); _providerUsage.RecordError(providerId, modelId); + _promptCacheCoordinator.RecordResponse(descriptor, 0, 0); RecordEvent(session, turnContext, "llm", "stream_failed", "error", ex.Message, new() { ["providerId"] = providerId, @@ -415,6 +485,13 @@ private async IAsyncEnumerable StreamWithCircuitAsync( throw; } + foreach (var usage in current.Contents.OfType()) + { + var cacheUsage = PromptCacheUsageExtractor.FromUsage(usage.Details); + if (cacheUsage != PromptCacheUsage.Empty) + _promptCacheCoordinator.RecordResponse(descriptor, cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); + } + yield return current; } @@ -513,11 +590,60 @@ private bool TryCreateEffectiveOptions( MaxOutputTokens = maxOutputTokens, Temperature = source.Temperature, Tools = source.Tools, - ResponseFormat = source.ResponseFormat + ResponseFormat = source.ResponseFormat, + ConversationId = source.ConversationId, + Instructions = source.Instructions, + TopP = source.TopP, + TopK = source.TopK, + FrequencyPenalty = source.FrequencyPenalty, + PresencePenalty = source.PresencePenalty, + Seed = source.Seed, + Reasoning = source.Reasoning, + StopSequences = source.StopSequences?.ToList(), + AllowMultipleToolCalls = source.AllowMultipleToolCalls, + ToolMode = source.ToolMode, + AdditionalProperties = source.AdditionalProperties?.Clone() }; return true; } + private static void NormalizePromptCacheUsage(ChatResponse response) + { + if (response.Usage is null) + response.Usage = new UsageDetails(); + + if (response.AdditionalProperties is null) + return; + + if (response.Usage.CachedInputTokenCount is null && + TryReadLong(response.AdditionalProperties, "cache_read_tokens", out var cacheRead)) + { + response.Usage.CachedInputTokenCount = cacheRead; + } + + if (TryReadLong(response.AdditionalProperties, "cache_write_tokens", out var cacheWrite) || + TryReadLong(response.AdditionalProperties, "cache_creation_input_tokens", out cacheWrite)) + { + response.Usage.AdditionalCounts ??= new AdditionalPropertiesDictionary(); + response.Usage.AdditionalCounts["cache_write_tokens"] = cacheWrite; + } + } + + private static bool TryReadLong(IReadOnlyDictionary properties, string key, out long value) + { + value = 0; + if (!properties.TryGetValue(key, out var raw) || raw is null) + return false; + + return raw switch + { + long longValue => (value = longValue) >= 0, + int intValue => (value = intValue) >= 0, + string text when long.TryParse(text, out var parsed) => (value = parsed) >= 0, + _ => false + }; + } + private string ResolveRequestedModelId(Session session, ModelProfile profile) { if 
(!string.IsNullOrWhiteSpace(session.ModelOverride) && !_modelProfiles.TryGet(session.ModelOverride!, out _)) diff --git a/src/OpenClaw.Gateway/HeartbeatService.cs b/src/OpenClaw.Gateway/HeartbeatService.cs index d2f49aa..a7ad3c1 100644 --- a/src/OpenClaw.Gateway/HeartbeatService.cs +++ b/src/OpenClaw.Gateway/HeartbeatService.cs @@ -851,8 +851,13 @@ private async ValueTask> BuildSuggestionsA if (session is null) continue; - foreach (var turn in session.History.Where(static turn => string.Equals(turn.Role, "user", StringComparison.OrdinalIgnoreCase))) - texts.Add(($"session:{summary.Id}", turn.Content)); + foreach (var turn in session.History + .Where(static turn => string.Equals(turn.Role, "user", StringComparison.OrdinalIgnoreCase)) + .TakeLast(6)) + { + if (!string.IsNullOrWhiteSpace(turn.Content)) + texts.Add(($"session:{summary.Id}", Truncate(turn.Content, 2_000))); + } } } @@ -861,7 +866,7 @@ private async ValueTask> BuildSuggestionsA { var note = await _memoryStore.LoadNoteAsync(key, ct); if (!string.IsNullOrWhiteSpace(note)) - texts.Add(($"note:{key}", note!)); + texts.Add(($"note:{key}", Truncate(note!, 2_000))); } foreach (var (source, text) in texts) diff --git a/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs index 9799723..709a54c 100644 --- a/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs +++ b/src/OpenClaw.Gateway/Models/ConfiguredModelProfileRegistry.cs @@ -69,6 +69,7 @@ public IReadOnlyList ListStatuses() IsAvailable = item.Client is not null && item.ValidationIssues.Length == 0, Tags = item.Profile.Tags, Capabilities = item.Profile.Capabilities, + PromptCaching = item.Profile.PromptCaching, ValidationIssues = item.ValidationIssues, FallbackProfileIds = item.Profile.FallbackProfileIds, FallbackModels = item.Profile.FallbackModels @@ -149,14 +150,19 @@ private static ModelProfileConfig CreateImplicitConfig(GatewayConfig config) BaseUrl = config.Llm.Endpoint, ApiKey = config.Llm.ApiKey, FallbackModels = config.Llm.FallbackModels, - Capabilities = GuessCapabilities(config.Llm.Provider) + Capabilities = GuessCapabilities(config.Llm.Provider), + PromptCaching = ClonePromptCaching(config.Llm.PromptCaching) }; private static ModelCapabilities GuessCapabilities(string providerId) { var provider = (providerId ?? 
string.Empty).Trim().ToLowerInvariant(); - var supportsTools = provider is "openai" or "openai-compatible" or "azure-openai" or "groq" or "together" or "lmstudio" or "anthropic" or "claude" or "gemini" or "google"; - var supportsVision = provider is "openai" or "openai-compatible" or "azure-openai" or "gemini" or "google" or "ollama"; + var supportsTools = provider is "openai" or "openai-compatible" or "azure-openai" or "groq" or "together" or "lmstudio" or "anthropic" or "claude" or "anthropic-vertex" or "amazon-bedrock" or "gemini" or "google"; + var supportsVision = provider is "openai" or "openai-compatible" or "azure-openai" or "gemini" or "google" or "ollama" or "amazon-bedrock"; + var supportsPromptCaching = provider is "openai" or "azure-openai" or "anthropic" or "claude" or "anthropic-vertex" or "gemini" or "google"; + var supportsExplicitCacheRetention = provider is "anthropic" or "claude" or "anthropic-vertex"; + var reportsCacheReadTokens = supportsPromptCaching; + var reportsCacheWriteTokens = provider is "anthropic" or "claude" or "anthropic-vertex"; return new ModelCapabilities { SupportsTools = supportsTools, @@ -168,7 +174,11 @@ private static ModelCapabilities GuessCapabilities(string providerId) SupportsReasoningEffort = provider is "openai" or "openai-compatible" or "azure-openai", SupportsSystemMessages = true, SupportsImageInput = supportsVision, - SupportsAudioInput = provider is "openai" or "openai-compatible" or "azure-openai" + SupportsAudioInput = provider is "openai" or "openai-compatible" or "azure-openai", + SupportsPromptCaching = supportsPromptCaching, + SupportsExplicitCacheRetention = supportsExplicitCacheRetention, + ReportsCacheReadTokens = reportsCacheReadTokens, + ReportsCacheWriteTokens = reportsCacheWriteTokens }; } @@ -184,6 +194,7 @@ private static ModelProfile ToProfile(GatewayConfig config, ModelProfileConfig m FallbackProfileIds = NormalizeDistinct(model.FallbackProfileIds), FallbackModels = NormalizeDistinct(model.FallbackModels), Capabilities = model.Capabilities ?? GuessCapabilities(Normalize(model.Provider) ?? 
config.Llm.Provider), + PromptCaching = MergePromptCaching(config.Llm.PromptCaching, model.PromptCaching), IsImplicit = string.Equals(model.Id, "default", StringComparison.OrdinalIgnoreCase) && config.Models.Profiles.Count == 0 }; @@ -200,11 +211,13 @@ private static IEnumerable ValidateProfile(ModelProfile profile, Gateway profile.ProviderId.Equals("groq", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("together", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("lmstudio", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("anthropic-vertex", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("amazon-bedrock", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("azure-openai", StringComparison.OrdinalIgnoreCase)) && string.IsNullOrWhiteSpace(profile.BaseUrl) && string.IsNullOrWhiteSpace(config.Llm.Endpoint)) { - yield return "BaseUrl is required for OpenAI-compatible and Azure OpenAI profiles unless inherited from OpenClaw:Llm:Endpoint."; + yield return "BaseUrl is required for OpenAI-compatible, Anthropic Vertex, Amazon Bedrock, and Azure OpenAI profiles unless inherited from OpenClaw:Llm:Endpoint."; } if ((profile.ProviderId.Equals("openai", StringComparison.OrdinalIgnoreCase) || @@ -214,6 +227,8 @@ private static IEnumerable ValidateProfile(ModelProfile profile, Gateway profile.ProviderId.Equals("azure-openai", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("anthropic", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("claude", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("anthropic-vertex", StringComparison.OrdinalIgnoreCase) || + profile.ProviderId.Equals("amazon-bedrock", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("gemini", StringComparison.OrdinalIgnoreCase) || profile.ProviderId.Equals("google", StringComparison.OrdinalIgnoreCase)) && string.IsNullOrWhiteSpace(profile.ApiKey) && @@ -236,7 +251,8 @@ internal static LlmProviderConfig BuildProviderConfig(GatewayConfig config, Mode TimeoutSeconds = config.Llm.TimeoutSeconds, RetryCount = config.Llm.RetryCount, CircuitBreakerThreshold = config.Llm.CircuitBreakerThreshold, - CircuitBreakerCooldownSeconds = config.Llm.CircuitBreakerCooldownSeconds + CircuitBreakerCooldownSeconds = config.Llm.CircuitBreakerCooldownSeconds, + PromptCaching = ClonePromptCaching(profile.PromptCaching) }; private static string? Normalize(string? value) @@ -259,6 +275,35 @@ private static string[] NormalizeDistinct(IEnumerable? values) .Distinct(StringComparer.OrdinalIgnoreCase) .ToArray(); + private static PromptCachingConfig MergePromptCaching(PromptCachingConfig inherited, PromptCachingConfig? configured) + { + if (configured is null) + return ClonePromptCaching(inherited); + + return new PromptCachingConfig + { + Enabled = configured.Enabled ?? inherited.Enabled, + Retention = string.IsNullOrWhiteSpace(configured.Retention) ? inherited.Retention : configured.Retention, + Dialect = string.IsNullOrWhiteSpace(configured.Dialect) ? inherited.Dialect : configured.Dialect, + KeepWarmEnabled = configured.KeepWarmEnabled ?? inherited.KeepWarmEnabled, + KeepWarmIntervalMinutes = configured.KeepWarmIntervalMinutes > 0 ? configured.KeepWarmIntervalMinutes : inherited.KeepWarmIntervalMinutes, + TraceEnabled = configured.TraceEnabled ?? inherited.TraceEnabled, + TraceFilePath = string.IsNullOrWhiteSpace(configured.TraceFilePath) ? 
inherited.TraceFilePath : configured.TraceFilePath + }; + } + + private static PromptCachingConfig ClonePromptCaching(PromptCachingConfig source) + => new() + { + Enabled = source.Enabled, + Retention = source.Retention, + Dialect = source.Dialect, + KeepWarmEnabled = source.KeepWarmEnabled, + KeepWarmIntervalMinutes = source.KeepWarmIntervalMinutes, + TraceEnabled = source.TraceEnabled, + TraceFilePath = source.TraceFilePath + }; + private bool TryResolveRegisteredClient(ModelProfile profile, out IChatClient? client) { client = null; diff --git a/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs index 8980e58..69e32d3 100644 --- a/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs +++ b/src/OpenClaw.Gateway/Models/DefaultModelSelectionPolicy.cs @@ -65,10 +65,35 @@ public ModelSelectionResult Resolve(ModelSelectionRequest request) if (candidates.Length > 0) return BuildResult(null, candidates[0].Profile, requirements, preferredTags, candidates, null); + if (TryResolveLegacyImplicitDefault(out var legacyDefault) && IsSelectable(legacyDefault)) + { + return BuildResult( + requestedProfileId: null, + legacyDefault.Profile, + requirements, + preferredTags, + [ToCandidate(legacyDefault)], + "Using the implicit default model profile because no explicit model-profile configuration is present."); + } + throw new ModelSelectionException( $"No configured model profile satisfies the current request requirements ({DescribeRequirementSummary(requirements)})."); } + private bool TryResolveLegacyImplicitDefault(out ConfiguredModelProfileRegistry.Registration registration) + { + registration = null!; + var statuses = _registry.ListStatuses(); + if (statuses.Count != 1 || !statuses[0].IsImplicit) + return false; + + if (!_registry.TryGetRegistration(statuses[0].Id, out var resolved) || resolved is null) + return false; + + registration = resolved; + return true; + } + private static ModelSelectionResult BuildResult( string? 
requestedProfileId, ModelProfile selectedProfile, diff --git a/src/OpenClaw.Gateway/PromptCaching/PromptCacheCoordinator.cs b/src/OpenClaw.Gateway/PromptCaching/PromptCacheCoordinator.cs new file mode 100644 index 0000000..4850618 --- /dev/null +++ b/src/OpenClaw.Gateway/PromptCaching/PromptCacheCoordinator.cs @@ -0,0 +1,268 @@ +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using Microsoft.Extensions.AI; +using OpenClaw.Core.Models; + +namespace OpenClaw.Gateway.PromptCaching; + +internal sealed class PromptCachePreparedRequest +{ + public required IReadOnlyList Messages { get; init; } + public required ChatOptions Options { get; init; } + public required PromptCacheDescriptor Descriptor { get; init; } +} + +internal sealed class PromptCacheDescriptor +{ + public required string SessionId { get; init; } + public required string ProfileId { get; init; } + public required string ProviderId { get; init; } + public required string ModelId { get; init; } + public required string Dialect { get; init; } + public required string Retention { get; init; } + public required string StableFingerprint { get; init; } + public required string StableSystemPrompt { get; init; } + public required string VolatileSuffix { get; init; } + public required string ToolSignature { get; init; } + public required DateTimeOffset CreatedAtUtc { get; init; } + public bool Enabled { get; init; } + public bool KeepWarmEligible { get; init; } +} + +internal sealed class PromptCacheWarmCandidate +{ + public required PromptCacheDescriptor Descriptor { get; init; } + public required IReadOnlyList WarmMessages { get; init; } + public required ChatOptions WarmOptions { get; init; } + public DateTimeOffset LastSeenAtUtc { get; set; } = DateTimeOffset.UtcNow; + public DateTimeOffset? 
LastWarmedAtUtc { get; set; } +} + +internal sealed class PromptCacheWarmRegistry +{ + private readonly ConcurrentDictionary _entries = new(StringComparer.Ordinal); + + public void Record(PromptCachePreparedRequest request) + { + if (!request.Descriptor.Enabled || !request.Descriptor.KeepWarmEligible || string.IsNullOrWhiteSpace(request.Descriptor.StableSystemPrompt)) + return; + + var key = BuildKey(request.Descriptor.SessionId, request.Descriptor.ProfileId); + _entries[key] = new PromptCacheWarmCandidate + { + Descriptor = request.Descriptor, + WarmMessages = [new ChatMessage(ChatRole.System, request.Descriptor.StableSystemPrompt)], + WarmOptions = new ChatOptions + { + ModelId = request.Options.ModelId, + Tools = request.Options.Tools, + ResponseFormat = request.Options.ResponseFormat, + AdditionalProperties = request.Options.AdditionalProperties?.Clone(), + MaxOutputTokens = 1, + Temperature = 0 + }, + LastSeenAtUtc = DateTimeOffset.UtcNow + }; + } + + public IReadOnlyList Snapshot() => _entries.Values.ToArray(); + + public void MarkWarmed(PromptCacheWarmCandidate candidate, DateTimeOffset warmedAtUtc) + { + var key = BuildKey(candidate.Descriptor.SessionId, candidate.Descriptor.ProfileId); + if (_entries.TryGetValue(key, out var current)) + current.LastWarmedAtUtc = warmedAtUtc; + } + + public void Prune(IReadOnlySet activeSessionIds, DateTimeOffset staleBeforeUtc) + { + foreach (var entry in _entries) + { + if (!activeSessionIds.Contains(entry.Value.Descriptor.SessionId) || entry.Value.LastSeenAtUtc < staleBeforeUtc) + _entries.TryRemove(entry.Key, out _); + } + } + + private static string BuildKey(string sessionId, string profileId) => $"{sessionId}:{profileId}"; +} + +internal sealed class PromptCacheCoordinator +{ + private const string RouteInstructionsMarker = "\n\n[Route Instructions]\n"; + private readonly GatewayConfig _config; + private readonly PromptCacheTraceWriter _traceWriter; + + public PromptCacheCoordinator(GatewayConfig config, PromptCacheTraceWriter traceWriter) + { + _config = config; + _traceWriter = traceWriter; + } + + public PromptCachePreparedRequest Prepare( + Session session, + ModelProfile profile, + string modelId, + IReadOnlyList messages, + ChatOptions options) + { + var caching = profile.PromptCaching; + var dialect = ResolveDialect(profile.ProviderId, caching.Dialect); + var retention = NormalizeRetention(caching.Retention); + var (stableSystem, volatileSuffix) = ExtractSystemPromptSegments(messages); + var toolSignature = BuildToolSignature(options); + var stableFingerprint = BuildStableFingerprint(profile.ProviderId, modelId, stableSystem, toolSignature, options.ResponseFormat); + var preparedOptions = CloneOptions(options); + preparedOptions.ModelId = modelId; + + if (caching.Enabled == true && dialect != "none" && profile.Capabilities.SupportsPromptCaching) + { + preparedOptions.AdditionalProperties ??= new AdditionalPropertiesDictionary(); + preparedOptions.AdditionalProperties["openclaw_prompt_cache_enabled"] = true; + preparedOptions.AdditionalProperties["openclaw_prompt_cache_dialect"] = dialect; + preparedOptions.AdditionalProperties["openclaw_prompt_cache_retention"] = retention; + preparedOptions.AdditionalProperties["openclaw_prompt_cache_fingerprint"] = stableFingerprint; + preparedOptions.AdditionalProperties["openclaw_prompt_cache_keep_warm"] = caching.KeepWarmEnabled == true; + + switch (dialect) + { + case "openai": + preparedOptions.AdditionalProperties["prompt_cache_key"] = stableFingerprint; + if (retention == "long") + 
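+ // Only "long" retention adds OpenAI's 24h retention hint; every other
+ // retention value sends just the deterministic prompt_cache_key fingerprint
+ // computed above.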
preparedOptions.AdditionalProperties["prompt_cache_retention"] = "24h"; + break; + case "anthropic": + preparedOptions.AdditionalProperties["anthropic_cache_key"] = stableFingerprint; + preparedOptions.AdditionalProperties["anthropic_cache_control"] = retention == "long" ? "1h" : "ephemeral"; + break; + case "gemini": + preparedOptions.AdditionalProperties["gemini_cached_content_key"] = stableFingerprint; + break; + } + } + + var descriptor = new PromptCacheDescriptor + { + SessionId = session.Id, + ProfileId = profile.Id, + ProviderId = profile.ProviderId, + ModelId = modelId, + Dialect = dialect, + Retention = retention, + StableFingerprint = stableFingerprint, + StableSystemPrompt = stableSystem, + VolatileSuffix = volatileSuffix, + ToolSignature = toolSignature, + CreatedAtUtc = DateTimeOffset.UtcNow, + Enabled = caching.Enabled == true && dialect != "none" && profile.Capabilities.SupportsPromptCaching, + KeepWarmEligible = caching.KeepWarmEnabled == true && SupportsKeepWarm(profile.ProviderId, dialect) + }; + + _traceWriter.WriteRequest(descriptor, messages, preparedOptions); + return new PromptCachePreparedRequest + { + Messages = messages, + Options = preparedOptions, + Descriptor = descriptor + }; + } + + public void RecordResponse(PromptCacheDescriptor descriptor, long cacheReadTokens, long cacheWriteTokens) + => _traceWriter.WriteResponse(descriptor, cacheReadTokens, cacheWriteTokens); + + public static string ResolveDialect(string providerId, string? configuredDialect) + { + var dialect = (configuredDialect ?? "auto").Trim().ToLowerInvariant(); + if (dialect != "auto") + return dialect; + + var provider = (providerId ?? string.Empty).Trim().ToLowerInvariant(); + return provider switch + { + "openai" or "azure-openai" => "openai", + "anthropic" or "claude" or "anthropic-vertex" or "amazon-bedrock" => "anthropic", + "gemini" or "google" => "gemini", + _ => "none" + }; + } + + private static bool SupportsKeepWarm(string providerId, string dialect) + { + var provider = (providerId ?? string.Empty).Trim().ToLowerInvariant(); + return dialect == "anthropic" && provider is "anthropic" or "claude" or "anthropic-vertex" or "amazon-bedrock" + || dialect == "gemini" && provider is "gemini" or "google"; + } + + private static string NormalizeRetention(string? retention) + { + var value = (retention ?? "auto").Trim().ToLowerInvariant(); + return value is "none" or "short" or "long" ? value : "auto"; + } + + private static (string StableSystemPrompt, string VolatileSuffix) ExtractSystemPromptSegments(IReadOnlyList messages) + { + var firstSystem = messages.FirstOrDefault(static message => message.Role == ChatRole.System)?.Text ?? string.Empty; + if (string.IsNullOrWhiteSpace(firstSystem)) + return (string.Empty, string.Empty); + + var markerIndex = firstSystem.IndexOf(RouteInstructionsMarker, StringComparison.Ordinal); + if (markerIndex < 0) + return (NormalizeText(firstSystem), string.Empty); + + return ( + NormalizeText(firstSystem[..markerIndex]), + NormalizeText(firstSystem[(markerIndex + RouteInstructionsMarker.Length)..])); + } + + private static string BuildToolSignature(ChatOptions options) + { + if (options.Tools is null || options.Tools.Count == 0) + return string.Empty; + + var signatures = options.Tools + .Select(static tool => + { + var schema = tool is AIFunctionDeclaration declaration ? 
declaration.JsonSchema.GetRawText() : string.Empty; + return $"{tool.Name}|{tool.Description}|{schema}"; + }) + .OrderBy(static item => item, StringComparer.Ordinal) + .ToArray(); + return string.Join("\n", signatures); + } + + private static string BuildStableFingerprint(string providerId, string modelId, string stableSystem, string toolSignature, ChatResponseFormat? responseFormat) + { + var responseFormatSignature = responseFormat is null ? string.Empty : responseFormat.GetType().FullName ?? responseFormat.ToString() ?? string.Empty; + var payload = string.Join("\n---\n", NormalizeText(providerId), NormalizeText(modelId), stableSystem, toolSignature, responseFormatSignature); + var hash = SHA256.HashData(Encoding.UTF8.GetBytes(payload)); + return Convert.ToHexString(hash).ToLowerInvariant(); + } + + private static string NormalizeText(string? value) + => string.IsNullOrWhiteSpace(value) + ? string.Empty + : value.Replace("\r\n", "\n", StringComparison.Ordinal).Trim(); + + private static ChatOptions CloneOptions(ChatOptions source) + => new() + { + ConversationId = source.ConversationId, + Instructions = source.Instructions, + Temperature = source.Temperature, + MaxOutputTokens = source.MaxOutputTokens, + TopP = source.TopP, + TopK = source.TopK, + FrequencyPenalty = source.FrequencyPenalty, + PresencePenalty = source.PresencePenalty, + Seed = source.Seed, + Reasoning = source.Reasoning, + ResponseFormat = source.ResponseFormat, + ModelId = source.ModelId, + StopSequences = source.StopSequences?.ToList(), + AllowMultipleToolCalls = source.AllowMultipleToolCalls, + ToolMode = source.ToolMode, + Tools = source.Tools, + AdditionalProperties = source.AdditionalProperties?.Clone() + }; +} diff --git a/src/OpenClaw.Gateway/PromptCaching/PromptCacheTraceWriter.cs b/src/OpenClaw.Gateway/PromptCaching/PromptCacheTraceWriter.cs new file mode 100644 index 0000000..bbd3c7f --- /dev/null +++ b/src/OpenClaw.Gateway/PromptCaching/PromptCacheTraceWriter.cs @@ -0,0 +1,175 @@ +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Extensions.AI; +using OpenClaw.Core.Models; + +namespace OpenClaw.Gateway.PromptCaching; + +internal sealed class PromptCacheTraceWriter +{ + private readonly GatewayConfig _config; + private readonly Lock _gate = new(); + + public PromptCacheTraceWriter(GatewayConfig config) + { + _config = config; + } + + public void WriteRequest(PromptCacheDescriptor descriptor, IReadOnlyList messages, ChatOptions options) + { + if (!IsEnabled(descriptor)) + return; + + Write(new PromptCacheTraceEntry + { + TimestampUtc = DateTimeOffset.UtcNow, + Event = "request", + SessionId = descriptor.SessionId, + ProfileId = descriptor.ProfileId, + ProviderId = descriptor.ProviderId, + ModelId = descriptor.ModelId, + Dialect = descriptor.Dialect, + Retention = descriptor.Retention, + Fingerprint = descriptor.StableFingerprint, + StableSystemPrompt = ShouldIncludeSystem() ? descriptor.StableSystemPrompt : null, + PromptText = ShouldIncludePrompt() ? 
descriptor.VolatileSuffix : null, + MessageCount = messages.Count, + AdditionalProperties = options.AdditionalProperties?.ToDictionary(static kvp => kvp.Key, static kvp => RenderPropertyValue(kvp.Value)) + }); + } + + public void WriteResponse(PromptCacheDescriptor descriptor, long cacheReadTokens, long cacheWriteTokens) + { + if (!IsEnabled(descriptor)) + return; + + Write(new PromptCacheTraceEntry + { + TimestampUtc = DateTimeOffset.UtcNow, + Event = "response", + SessionId = descriptor.SessionId, + ProfileId = descriptor.ProfileId, + ProviderId = descriptor.ProviderId, + ModelId = descriptor.ModelId, + Dialect = descriptor.Dialect, + Retention = descriptor.Retention, + Fingerprint = descriptor.StableFingerprint, + CacheReadTokens = cacheReadTokens, + CacheWriteTokens = cacheWriteTokens + }); + } + + private bool IsEnabled(PromptCacheDescriptor descriptor) + { + if (GetBoolEnv("OPENCLAW_CACHE_TRACE")) + return true; + + return descriptor.Enabled && (_config.Diagnostics.CacheTrace.Enabled || _config.Llm.PromptCaching.TraceEnabled == true); + } + + private bool ShouldIncludePrompt() => GetBoolEnv("OPENCLAW_CACHE_TRACE_PROMPT", _config.Diagnostics.CacheTrace.IncludePrompt); + private bool ShouldIncludeSystem() => GetBoolEnv("OPENCLAW_CACHE_TRACE_SYSTEM", _config.Diagnostics.CacheTrace.IncludeSystem); + + private string ResolvePath(PromptCacheDescriptor descriptor) + { + var env = Environment.GetEnvironmentVariable("OPENCLAW_CACHE_TRACE_FILE"); + if (!string.IsNullOrWhiteSpace(env)) + return Path.GetFullPath(env); + if (!string.IsNullOrWhiteSpace(descriptor.StableFingerprint) && !string.IsNullOrWhiteSpace(_config.Llm.PromptCaching.TraceFilePath)) + return Path.GetFullPath(_config.Llm.PromptCaching.TraceFilePath); + if (!string.IsNullOrWhiteSpace(_config.Diagnostics.CacheTrace.FilePath)) + return Path.GetFullPath(_config.Diagnostics.CacheTrace.FilePath); + return Path.GetFullPath(Path.Combine(_config.Memory.StoragePath, "logs", "cache-trace.jsonl")); + } + + private void Write(PromptCacheTraceEntry entry) + { + var path = ResolvePath(new PromptCacheDescriptor + { + SessionId = entry.SessionId ?? "unknown", + ProfileId = entry.ProfileId ?? "unknown", + ProviderId = entry.ProviderId ?? "unknown", + ModelId = entry.ModelId ?? "unknown", + Dialect = entry.Dialect ?? "none", + Retention = entry.Retention ?? "auto", + StableFingerprint = entry.Fingerprint ?? string.Empty, + StableSystemPrompt = string.Empty, + VolatileSuffix = string.Empty, + ToolSignature = string.Empty, + CreatedAtUtc = entry.TimestampUtc, + Enabled = true, + KeepWarmEligible = false + }); + Directory.CreateDirectory(Path.GetDirectoryName(path)!); + var json = JsonSerializer.Serialize(entry, PromptCacheTraceJsonContext.Default.PromptCacheTraceEntry); + lock (_gate) + { + File.AppendAllText(path, json + Environment.NewLine); + } + } + + private static bool GetBoolEnv(string name, bool fallback = false) + { + var raw = Environment.GetEnvironmentVariable(name); + if (string.IsNullOrWhiteSpace(raw)) + return fallback; + + return raw.Trim() switch + { + "1" => true, + "0" => false, + _ => bool.TryParse(raw, out var parsed) ? parsed : fallback + }; + } + + private static string? RenderPropertyValue(object? value) + => value switch + { + null => null, + string stringValue => stringValue, + bool boolValue => boolValue ? 
"true" : "false", + byte byteValue => byteValue.ToString(), + sbyte sbyteValue => sbyteValue.ToString(), + short shortValue => shortValue.ToString(), + ushort ushortValue => ushortValue.ToString(), + int intValue => intValue.ToString(), + uint uintValue => uintValue.ToString(), + long longValue => longValue.ToString(), + ulong ulongValue => ulongValue.ToString(), + float floatValue => floatValue.ToString(), + double doubleValue => doubleValue.ToString(), + decimal decimalValue => decimalValue.ToString(), + Guid guidValue => guidValue.ToString(), + DateTime dateTimeValue => dateTimeValue.ToString("O"), + DateTimeOffset dateTimeOffsetValue => dateTimeOffsetValue.ToString("O"), + TimeSpan timeSpanValue => timeSpanValue.ToString(), + Uri uriValue => uriValue.ToString(), + _ => "[OMITTED]" + }; + + internal sealed class PromptCacheTraceEntry + { + public DateTimeOffset TimestampUtc { get; init; } + public string? Event { get; init; } + public string? SessionId { get; init; } + public string? ProfileId { get; init; } + public string? ProviderId { get; init; } + public string? ModelId { get; init; } + public string? Dialect { get; init; } + public string? Retention { get; init; } + public string? Fingerprint { get; init; } + public string? StableSystemPrompt { get; init; } + public string? PromptText { get; init; } + public int MessageCount { get; init; } + public long CacheReadTokens { get; init; } + public long CacheWriteTokens { get; init; } + public Dictionary? AdditionalProperties { get; init; } + } +} + +[JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + WriteIndented = false)] +[JsonSerializable(typeof(PromptCacheTraceWriter.PromptCacheTraceEntry))] +internal sealed partial class PromptCacheTraceJsonContext : JsonSerializerContext; diff --git a/src/OpenClaw.Gateway/PromptCaching/PromptCacheWarmService.cs b/src/OpenClaw.Gateway/PromptCaching/PromptCacheWarmService.cs new file mode 100644 index 0000000..48c63f4 --- /dev/null +++ b/src/OpenClaw.Gateway/PromptCaching/PromptCacheWarmService.cs @@ -0,0 +1,135 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using OpenClaw.Core.Observability; +using OpenClaw.Core.Sessions; +using OpenClaw.Gateway.Models; + +namespace OpenClaw.Gateway.PromptCaching; + +internal sealed class PromptCacheWarmService : BackgroundService +{ + private readonly SessionManager _sessions; + private readonly ConfiguredModelProfileRegistry _profiles; + private readonly PromptCacheWarmRegistry _warmRegistry; + private readonly RuntimeMetrics _metrics; + private readonly RuntimeEventStore _eventStore; + private readonly ILogger _logger; + + public PromptCacheWarmService( + SessionManager sessions, + ConfiguredModelProfileRegistry profiles, + PromptCacheWarmRegistry warmRegistry, + RuntimeMetrics metrics, + RuntimeEventStore eventStore, + ILogger logger) + { + _sessions = sessions; + _profiles = profiles; + _warmRegistry = warmRegistry; + _metrics = metrics; + _eventStore = eventStore; + _logger = logger; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + while (!stoppingToken.IsCancellationRequested) + { + try + { + await RunSweepAsync(stoppingToken); + } + catch (OperationCanceledException) when (stoppingToken.IsCancellationRequested) + { + break; + } + catch (Exception ex) + { + _metrics.IncrementPromptCacheWarmFailures(); + _logger.LogWarning(ex, "Prompt cache warm sweep 
failed."); + } + + await Task.Delay(TimeSpan.FromMinutes(1), stoppingToken); + } + } + + private async Task RunSweepAsync(CancellationToken ct) + { + var activeSessionIds = (await _sessions.ListActiveAsync(ct)) + .Select(static session => session.Id) + .ToHashSet(StringComparer.Ordinal); + var now = DateTimeOffset.UtcNow; + _warmRegistry.Prune(activeSessionIds, now - TimeSpan.FromHours(6)); + + foreach (var candidate in _warmRegistry.Snapshot()) + { + if (!activeSessionIds.Contains(candidate.Descriptor.SessionId)) + { + _metrics.IncrementPromptCacheWarmSkips(); + continue; + } + + if (!_profiles.TryGetRegistration(candidate.Descriptor.ProfileId, out var registration) || registration?.Client is null) + { + _metrics.IncrementPromptCacheWarmSkips(); + continue; + } + + var intervalMinutes = Math.Max(5, registration.Profile.PromptCaching.KeepWarmIntervalMinutes); + if (candidate.LastWarmedAtUtc is not null && now - candidate.LastWarmedAtUtc < TimeSpan.FromMinutes(intervalMinutes)) + { + _metrics.IncrementPromptCacheWarmSkips(); + continue; + } + + try + { + await registration.Client.GetResponseAsync(candidate.WarmMessages, candidate.WarmOptions, ct); + candidate.LastWarmedAtUtc = now; + _warmRegistry.MarkWarmed(candidate, now); + _metrics.IncrementPromptCacheWarmRuns(); + _eventStore.Append(new Core.Models.RuntimeEventEntry + { + Id = $"evt_{Guid.NewGuid():N}"[..20], + SessionId = candidate.Descriptor.SessionId, + Component = "prompt_cache", + Action = "warm", + Severity = "info", + Summary = $"Prompt cache warmed for profile '{candidate.Descriptor.ProfileId}'.", + Metadata = new Dictionary + { + ["profileId"] = candidate.Descriptor.ProfileId, + ["providerId"] = candidate.Descriptor.ProviderId, + ["modelId"] = candidate.Descriptor.ModelId, + ["fingerprint"] = candidate.Descriptor.StableFingerprint + } + }); + } + catch (OperationCanceledException) when (ct.IsCancellationRequested) + { + throw; + } + catch (Exception ex) + { + _metrics.IncrementPromptCacheWarmFailures(); + _logger.LogDebug(ex, "Prompt cache warm attempt failed for profile {ProfileId}", candidate.Descriptor.ProfileId); + _eventStore.Append(new Core.Models.RuntimeEventEntry + { + Id = $"evt_{Guid.NewGuid():N}"[..20], + SessionId = candidate.Descriptor.SessionId, + Component = "prompt_cache", + Action = "warm_failed", + Severity = "warning", + Summary = ex.Message, + Metadata = new Dictionary + { + ["profileId"] = candidate.Descriptor.ProfileId, + ["providerId"] = candidate.Descriptor.ProviderId, + ["modelId"] = candidate.Descriptor.ModelId + } + }); + } + } + } +} diff --git a/src/OpenClaw.Gateway/Tools/SessionStatusTool.cs b/src/OpenClaw.Gateway/Tools/SessionStatusTool.cs index 3a82a16..75d3726 100644 --- a/src/OpenClaw.Gateway/Tools/SessionStatusTool.cs +++ b/src/OpenClaw.Gateway/Tools/SessionStatusTool.cs @@ -2,6 +2,7 @@ using System.Text.Json; using OpenClaw.Core.Abstractions; using OpenClaw.Core.Models; +using OpenClaw.Core.Observability; using OpenClaw.Core.Sessions; namespace OpenClaw.Gateway.Tools; @@ -12,10 +13,12 @@ namespace OpenClaw.Gateway.Tools; internal sealed class SessionStatusTool : IToolWithContext { private readonly SessionManager _sessions; + private readonly ProviderUsageTracker? _providerUsage; - public SessionStatusTool(SessionManager sessions) + public SessionStatusTool(SessionManager sessions, ProviderUsageTracker? 
providerUsage = null) { _sessions = sessions; + _providerUsage = providerUsage; } public string Name => "session_status"; @@ -45,6 +48,8 @@ public ValueTask ExecuteAsync(string argumentsJson, ToolExecutionContext sb.AppendLine($" Sender: {session.SenderId}"); sb.AppendLine($" Turns: {session.History.Count}"); sb.AppendLine($" Tokens (in/out): {session.TotalInputTokens}/{session.TotalOutputTokens}"); + var (cacheReadTokens, cacheWriteTokens) = GetCacheTotals(session); + sb.AppendLine($" Prompt Cache (read/write): {cacheReadTokens}/{cacheWriteTokens}"); sb.AppendLine($" Created: {session.CreatedAt:u}"); sb.AppendLine($" Last Active: {session.LastActiveAt:u}"); sb.AppendLine($" Duration: {(int)duration.TotalHours}h {duration.Minutes}m"); @@ -59,4 +64,12 @@ public ValueTask ExecuteAsync(string argumentsJson, ToolExecutionContext => root.TryGetProperty(property, out var el) && el.ValueKind == JsonValueKind.String ? el.GetString() : null; + + private (long CacheReadTokens, long CacheWriteTokens) GetCacheTotals(Session session) + { + if (session.TotalCacheReadTokens > 0 || session.TotalCacheWriteTokens > 0) + return (session.TotalCacheReadTokens, session.TotalCacheWriteTokens); + + return _providerUsage?.GetLatestSessionCacheTotals(session.Id) ?? (0, 0); + } } diff --git a/src/OpenClaw.Gateway/appsettings.Production.json b/src/OpenClaw.Gateway/appsettings.Production.json index dca5160..4d57831 100644 --- a/src/OpenClaw.Gateway/appsettings.Production.json +++ b/src/OpenClaw.Gateway/appsettings.Production.json @@ -13,9 +13,26 @@ }, "Memory": { + "Provider": "sqlite", "StoragePath": "/app/memory", "MaxHistoryTurns": 50, - "MaxCachedSessions": 128 + "MaxCachedSessions": 128, + "Sqlite": { + "DbPath": "/app/memory/openclaw.db", + "EnableFts": true, + "EnableVectors": false + }, + "Retention": { + "Enabled": true, + "RunOnStartup": true, + "SweepIntervalMinutes": 30, + "SessionTtlDays": 30, + "BranchTtlDays": 14, + "ArchiveEnabled": true, + "ArchivePath": "/app/memory/archive", + "ArchiveRetentionDays": 30, + "MaxItemsPerSweep": 1000 + } }, "Security": { diff --git a/src/OpenClaw.Gateway/appsettings.json b/src/OpenClaw.Gateway/appsettings.json index 84d0388..41d29af 100644 --- a/src/OpenClaw.Gateway/appsettings.json +++ b/src/OpenClaw.Gateway/appsettings.json @@ -9,7 +9,16 @@ "ApiKey": null, "Endpoint": null, "MaxTokens": 4096, - "Temperature": 0.7 + "Temperature": 0.7, + "PromptCaching": { + "Enabled": false, + "Retention": "auto", + "Dialect": "auto", + "KeepWarmEnabled": false, + "KeepWarmIntervalMinutes": 55, + "TraceEnabled": false, + "TraceFilePath": null + } }, "Memory": { "Provider": "file", diff --git a/src/OpenClaw.Gateway/wwwroot/admin.html b/src/OpenClaw.Gateway/wwwroot/admin.html index 2bb5e97..8d77a2d 100644 --- a/src/OpenClaw.Gateway/wwwroot/admin.html +++ b/src/OpenClaw.Gateway/wwwroot/admin.html @@ -1532,7 +1532,7 @@

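The admin.html hunk below moves the `compactionThreshold` fallback from 40 to 80. That tracks the validator rule exercised in `ConfigValidatorTests` (`CompactionThreshold` must be greater than `MaxHistoryTurns`, whose fallback here is 50, so the old value of 40 could trip validation) and matches `FeatureParityTests` now asserting 80 as the `MemoryConfig` default. As a JSON sketch of the resulting defaults, using field names from this diff:

```json
{
  "OpenClaw": {
    "Memory": {
      "MaxHistoryTurns": 50,
      "EnableCompaction": false,
      "CompactionThreshold": 80,
      "CompactionKeepRecent": 10
    }
  }
}
```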

allowBrowserEvaluate: settingsInputs.allowBrowserEvaluate.checked, maxHistoryTurns: toInt(settingsInputs.maxHistoryTurns.value, 50), enableCompaction: settingsInputs.enableCompaction.checked, - compactionThreshold: toInt(settingsInputs.compactionThreshold.value, 40), + compactionThreshold: toInt(settingsInputs.compactionThreshold.value, 80), compactionKeepRecent: toInt(settingsInputs.compactionKeepRecent.value, 10), retentionEnabled: settingsInputs.retentionEnabled.checked, retentionRunOnStartup: settingsInputs.retentionRunOnStartup.checked, diff --git a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs index c6f4cf2..24824f7 100644 --- a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs +++ b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafAgentRuntime.cs @@ -43,6 +43,7 @@ public sealed class MafAgentRuntime : IAgentRuntime private readonly Func? _isContractTokenBudgetExceeded; private readonly Func? _isContractRuntimeBudgetExceeded; private readonly Action? _appendContractSnapshot; + private readonly string? _memoryRecallPrefix; private readonly object _skillGate = new(); private readonly IList _mafTools; private string _systemPrompt = string.Empty; @@ -93,6 +94,9 @@ public MafAgentRuntime( _isContractTokenBudgetExceeded = context.IsContractTokenBudgetExceeded; _isContractRuntimeBudgetExceeded = context.IsContractRuntimeBudgetExceeded; _appendContractSnapshot = context.AppendContractSnapshot; + var projectId = context.Config.Memory.ProjectId + ?? Environment.GetEnvironmentVariable("OPENCLAW_PROJECT"); + _memoryRecallPrefix = string.IsNullOrWhiteSpace(projectId) ? null : $"project:{projectId.Trim()}:"; _chatClient = new MafExecutionServiceChatClient( context.LlmExecutionService, context.RuntimeMetrics, @@ -507,9 +511,16 @@ private async ValueTask TryInjectRecallAsync(List messages, string try { var limit = Math.Clamp(_recall.MaxNotes, 1, 32); - var hits = await search.SearchNotesAsync(userMessage, prefix: null, limit, ct); + _metrics?.IncrementMemoryRecallSearches(); + var hits = await search.SearchNotesAsync(userMessage, _memoryRecallPrefix, limit, ct); + if (hits.Count == 0 && !string.IsNullOrWhiteSpace(_memoryRecallPrefix)) + { + _metrics?.IncrementMemoryRecallSearches(); + hits = await search.SearchNotesAsync(userMessage, prefix: null, limit, ct); + } if (hits.Count == 0) return; + _metrics?.AddMemoryRecallHits(hits.Count); var maxChars = Math.Clamp(_recall.MaxChars, 256, 100_000); var sb = new StringBuilder(); sb.AppendLine("[Relevant memory]"); @@ -612,6 +623,7 @@ private async Task CompactHistoryAsync(Session session, CancellationToken ct) return; } + _metrics?.IncrementMemoryCompactions(); session.History.RemoveRange(0, toSummarizeCount); session.History.Insert(0, new ChatTurn { @@ -691,13 +703,18 @@ private void RecordSummaryUsage( ?? LlmExecutionEstimateBuilder.EstimateInputTokens(messages); var outputTokens = execution.Response.Usage?.OutputTokenCount ?? LlmExecutionEstimateBuilder.EstimateTokenCount(execution.Response.Text?.Length ?? 
0); + var cacheUsage = PromptCacheUsageExtractor.FromUsage(execution.Response.Usage); session.AddTokenUsage(inputTokens, outputTokens); + session.AddCacheUsage(cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); turnContext.RecordLlmCall(elapsed, inputTokens, outputTokens); _metrics.IncrementLlmCalls(); _metrics.AddInputTokens(inputTokens); _metrics.AddOutputTokens(outputTokens); + _metrics.AddPromptCacheReads(cacheUsage.CacheReadTokens); + _metrics.AddPromptCacheWrites(cacheUsage.CacheWriteTokens); _providerUsage.AddTokens(execution.ProviderId, execution.ModelId, inputTokens, outputTokens); + _providerUsage.AddCacheTokens(execution.ProviderId, execution.ModelId, cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); _providerUsage.RecordTurn( session.Id, session.ChannelId, @@ -705,6 +722,8 @@ private void RecordSummaryUsage( execution.ModelId, inputTokens, outputTokens, + cacheUsage.CacheReadTokens, + cacheUsage.CacheWriteTokens, LlmExecutionEstimateBuilder.BuildInputTokenEstimate(messages, inputTokens, 0)); } diff --git a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafExecutionServiceChatClient.cs b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafExecutionServiceChatClient.cs index 41edb5d..9f86544 100644 --- a/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafExecutionServiceChatClient.cs +++ b/src/OpenClaw.MicrosoftAgentFrameworkAdapter/MafExecutionServiceChatClient.cs @@ -58,7 +58,8 @@ public async Task GetResponseAsync( result.ProviderId, result.ModelId, result.Response.Usage?.InputTokenCount, - result.Response.Usage?.OutputTokenCount); + result.Response.Usage?.OutputTokenCount, + PromptCacheUsageExtractor.FromUsage(result.Response.Usage)); _telemetry.TagProvider(Activity.Current, result.ProviderId, result.ModelId); return result.Response; @@ -86,6 +87,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( long? inputTokens = null; long? outputTokens = null; + var cacheUsage = PromptCacheUsage.Empty; var streamedText = new StringBuilder(); await foreach (var update in result.Updates.WithCancellation(cancellationToken)) @@ -102,6 +104,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( inputTokens = usage.Details.InputTokenCount.Value; if (usage.Details.OutputTokenCount is > 0) outputTokens = usage.Details.OutputTokenCount.Value; + cacheUsage = PromptCacheUsageExtractor.FromUsage(usage.Details); } yield return update; @@ -116,6 +119,7 @@ public async IAsyncEnumerable GetStreamingResponseAsync( result.ModelId, inputTokens, outputTokens, + cacheUsage, fallbackOutputLength: streamedText.Length); _telemetry.TagProvider(Activity.Current, result.ProviderId, result.ModelId); @@ -136,6 +140,7 @@ private void RecordUsage( string modelId, long? inputTokens, long? 
outputTokens, + PromptCacheUsage cacheUsage, int fallbackOutputLength = 0) { var resolvedInputTokens = inputTokens is > 0 @@ -149,11 +154,15 @@ private void RecordUsage( executionContext.TurnContext.RecordLlmCall(elapsed, resolvedInputTokens, resolvedOutputTokens); executionContext.Session.AddTokenUsage(resolvedInputTokens, resolvedOutputTokens); + executionContext.Session.AddCacheUsage(cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); executionContext.RecordContractTurnUsage?.Invoke(executionContext.Session, providerId, modelId, resolvedInputTokens, resolvedOutputTokens); _metrics.IncrementLlmCalls(); _metrics.AddInputTokens(resolvedInputTokens); _metrics.AddOutputTokens(resolvedOutputTokens); + _metrics.AddPromptCacheReads(cacheUsage.CacheReadTokens); + _metrics.AddPromptCacheWrites(cacheUsage.CacheWriteTokens); _providerUsage.AddTokens(providerId, modelId, resolvedInputTokens, resolvedOutputTokens); + _providerUsage.AddCacheTokens(providerId, modelId, cacheUsage.CacheReadTokens, cacheUsage.CacheWriteTokens); _providerUsage.RecordTurn( executionContext.Session.Id, executionContext.Session.ChannelId, @@ -161,6 +170,8 @@ private void RecordUsage( modelId, resolvedInputTokens, resolvedOutputTokens, + cacheUsage.CacheReadTokens, + cacheUsage.CacheWriteTokens, LlmExecutionEstimateBuilder.BuildInputTokenEstimate( messages, resolvedInputTokens, diff --git a/src/OpenClaw.Tests/ChatCommandProcessorTests.cs b/src/OpenClaw.Tests/ChatCommandProcessorTests.cs index f8110fa..fdb9f1c 100644 --- a/src/OpenClaw.Tests/ChatCommandProcessorTests.cs +++ b/src/OpenClaw.Tests/ChatCommandProcessorTests.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.Logging.Abstractions; using OpenClaw.Core.Memory; using OpenClaw.Core.Models; +using OpenClaw.Core.Observability; using OpenClaw.Core.Pipeline; using OpenClaw.Core.Sessions; using Xunit; @@ -59,4 +60,34 @@ public async Task Compact_Command_ReportsRemainingTurns() Assert.True(handled); Assert.Equal("Compacted: 4 turns → 6 turns remaining.", response); } + + [Fact] + public async Task Status_Command_UsesRecentUsageFallbackForPromptCacheCounters() + { + var store = new FileMemoryStore(System.IO.Path.Combine(System.IO.Path.GetTempPath(), "openclaw-command-tests", Guid.NewGuid().ToString("N")), 4); + var usage = new ProviderUsageTracker(); + usage.RecordTurn( + "sess-cache", + "websocket", + "openai", + "gpt-4.1", + inputTokens: 100, + outputTokens: 20, + cacheReadTokens: 512, + cacheWriteTokens: 0, + estimatedInputTokensByComponent: new InputTokenComponentEstimate()); + var processor = new ChatCommandProcessor(new SessionManager(store, new GatewayConfig(), NullLogger.Instance), usage); + + var session = new Session + { + Id = "sess-cache", + ChannelId = "websocket", + SenderId = "user1" + }; + + var (handled, response) = await processor.TryProcessCommandAsync(session, "/status", CancellationToken.None); + + Assert.True(handled); + Assert.Contains("Prompt Cache: 512 read / 0 write", response, StringComparison.Ordinal); + } } diff --git a/src/OpenClaw.Tests/ConfigValidatorTests.cs b/src/OpenClaw.Tests/ConfigValidatorTests.cs index e047aff..752c70c 100644 --- a/src/OpenClaw.Tests/ConfigValidatorTests.cs +++ b/src/OpenClaw.Tests/ConfigValidatorTests.cs @@ -189,6 +189,21 @@ public void Validate_CompactionThresholdMustExceedMaxHistoryTurns_ReturnsError() Assert.Contains(errors, e => e.Contains("greater than MaxHistoryTurns", StringComparison.Ordinal)); } + [Fact] + public void Validate_InvalidMemoryProvider_ReturnsError() + { + var config = new GatewayConfig + { + 
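+ // "redis" is not a memory provider this gateway wires up (CreateMemoryStore
+ // recognizes "sqlite" and otherwise falls back to the file store), so
+ // Validate is expected to surface a Memory.Provider error.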
+            Memory = new MemoryConfig
+            {
+                Provider = "redis"
+            }
+        };
+
+        var errors = ConfigValidator.Validate(config);
+        Assert.Contains(errors, e => e.Contains("Memory.Provider", StringComparison.Ordinal));
+    }
+
     [Fact]
     public void Validate_InvalidRuntimeMode_ReturnsError()
     {
diff --git a/src/OpenClaw.Tests/DynamicProviderRegistryCollection.cs b/src/OpenClaw.Tests/DynamicProviderRegistryCollection.cs
new file mode 100644
index 0000000..57ade14
--- /dev/null
+++ b/src/OpenClaw.Tests/DynamicProviderRegistryCollection.cs
@@ -0,0 +1,9 @@
+using Xunit;
+
+namespace OpenClaw.Tests;
+
+[CollectionDefinition(Name, DisableParallelization = true)]
+public sealed class DynamicProviderRegistryCollection : ICollectionFixture
+{
+    public const string Name = "Dynamic provider registry";
+}
diff --git a/src/OpenClaw.Tests/FeatureParityTests.cs b/src/OpenClaw.Tests/FeatureParityTests.cs
index dbf919f..cde20dd 100644
--- a/src/OpenClaw.Tests/FeatureParityTests.cs
+++ b/src/OpenClaw.Tests/FeatureParityTests.cs
@@ -879,7 +879,7 @@ public void GatewayConfig_MemoryConfig_HasCompaction()
     {
         var config = new MemoryConfig();
         Assert.False(config.EnableCompaction); // Default false
-        Assert.Equal(40, config.CompactionThreshold);
+        Assert.Equal(80, config.CompactionThreshold);
         Assert.Equal(10, config.CompactionKeepRecent);
     }

diff --git a/src/OpenClaw.Tests/FileMemoryStoreTests.cs b/src/OpenClaw.Tests/FileMemoryStoreTests.cs
index 1120fa8..0edfadb 100644
--- a/src/OpenClaw.Tests/FileMemoryStoreTests.cs
+++ b/src/OpenClaw.Tests/FileMemoryStoreTests.cs
@@ -186,4 +186,29 @@
             Directory.Delete(storagePath, recursive: true);
         }
     }
+
+    [Fact]
+    public async Task SearchNotesAsync_PrefersHigherScoringAndMoreRecentNotes()
+    {
+        var storagePath = Path.Combine(Path.GetTempPath(), "openclaw-file-memory-tests", Guid.NewGuid().ToString("N"));
+        Directory.CreateDirectory(storagePath);
+
+        try
+        {
+            var store = new FileMemoryStore(storagePath, 4);
+            await store.SaveNoteAsync("project:demo:legacy", "architecture notes about migration", CancellationToken.None);
+            await Task.Delay(20);
+            await store.SaveNoteAsync("project:demo:architecture", "architecture migration checklist", CancellationToken.None);
+
+            var hits = await store.SearchNotesAsync("architecture migration", "project:demo:", 2, CancellationToken.None);
+
+            Assert.Equal(2, hits.Count);
+            Assert.Equal("project:demo:architecture", hits[0].Key);
+            Assert.True(hits[0].Score >= hits[1].Score);
+        }
+        finally
+        {
+            Directory.Delete(storagePath, recursive: true);
+        }
+    }
 }
diff --git a/src/OpenClaw.Tests/LlmClientFactoryTests.cs b/src/OpenClaw.Tests/LlmClientFactoryTests.cs
index 275f342..e6fa164 100644
--- a/src/OpenClaw.Tests/LlmClientFactoryTests.cs
+++ b/src/OpenClaw.Tests/LlmClientFactoryTests.cs
@@ -6,6 +6,7 @@

 namespace OpenClaw.Tests;

+[Collection(DynamicProviderRegistryCollection.Name)]
 public sealed class LlmClientFactoryTests
 {
     [Fact]
diff --git a/src/OpenClaw.Tests/MemoryRecallInjectionTests.cs b/src/OpenClaw.Tests/MemoryRecallInjectionTests.cs
index 26fd7e9..81f0bc5 100644
--- a/src/OpenClaw.Tests/MemoryRecallInjectionTests.cs
+++ b/src/OpenClaw.Tests/MemoryRecallInjectionTests.cs
@@ -46,4 +46,44 @@ public async Task RunAsync_InsertsRelevantMemoryUserMessage_WhenEnabled()
             (m.Text ?? "").Contains("[Relevant memory]", StringComparison.Ordinal) &&
"").Contains("untrusted", StringComparison.OrdinalIgnoreCase)); } + + [Fact] + public async Task RunAsync_PrefersProjectScopedRecall_WhenProjectIdConfigured() + { + var chatClient = Substitute.For(); + chatClient.GetResponseAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(new ChatResponse(new[] { new ChatMessage(ChatRole.Assistant, "ok") }))); + + var memory = Substitute.For(); + var search = (IMemoryNoteSearch)memory; + search.SearchNotesAsync(Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any()) + .Returns(ValueTask.FromResult>([])); + + var agent = new AgentRuntime( + chatClient, + tools: [], + memory, + new LlmProviderConfig { Provider = "openai", ApiKey = "test", Model = "gpt-4" }, + maxHistoryTurns: 5, + recall: new MemoryRecallConfig { Enabled = true, MaxNotes = 5, MaxChars = 4000 }, + gatewayConfig: new GatewayConfig + { + Memory = new MemoryConfig + { + ProjectId = "demo" + } + }); + + var session = new Session { Id = "s1", ChannelId = "test", SenderId = "u1" }; + _ = await agent.RunAsync(session, "what should I remember?", CancellationToken.None); + + await search.Received().SearchNotesAsync( + "what should I remember?", + "project:demo:", + Arg.Any(), + Arg.Any()); + } } diff --git a/src/OpenClaw.Tests/MemoryRetentionSweeperServiceTests.cs b/src/OpenClaw.Tests/MemoryRetentionSweeperServiceTests.cs index b4d8f22..c0781e4 100644 --- a/src/OpenClaw.Tests/MemoryRetentionSweeperServiceTests.cs +++ b/src/OpenClaw.Tests/MemoryRetentionSweeperServiceTests.cs @@ -3,6 +3,7 @@ using OpenClaw.Core.Models; using OpenClaw.Core.Observability; using OpenClaw.Core.Sessions; +using OpenClaw.Gateway; using OpenClaw.Gateway.Extensions; using Xunit; @@ -131,9 +132,54 @@ public async Task SweepNowAsync_RejectsOverlappingRuns() await first; } + [Fact] + public async Task SweepNowAsync_ProtectsStarredSessionsFromMetadataStore() + { + var root = Path.Combine(Path.GetTempPath(), "openclaw-retention-tests", Guid.NewGuid().ToString("N")); + Directory.CreateDirectory(root); + + try + { + var config = new GatewayConfig + { + Memory = new MemoryConfig + { + StoragePath = root, + Retention = new MemoryRetentionConfig + { + Enabled = true, + ArchiveEnabled = false + } + } + }; + + var metadataStore = new SessionMetadataStore(root, NullLogger.Instance); + metadataStore.Set("session-starred", new SessionMetadataUpdateRequest { Starred = true }); + + var store = new StubRetentionStore(); + var manager = new SessionManager(store, config); + var service = new MemoryRetentionSweeperService( + config, + manager, + store, + new RuntimeMetrics(), + NullLogger.Instance, + metadataStore.GetAll); + + _ = await service.SweepNowAsync(dryRun: true, CancellationToken.None); + + Assert.Contains("session-starred", store.LastProtectedSessionIds); + } + finally + { + Directory.Delete(root, recursive: true); + } + } + private sealed class StubRetentionStore : IMemoryStore, IMemoryRetentionStore { public Func>? 
         public Func>? NextResultFactory { get; set; }
+        public IReadOnlySet<string> LastProtectedSessionIds { get; private set; } = new HashSet<string>(StringComparer.Ordinal);

         public ValueTask<Session?> GetSessionAsync(string sessionId, CancellationToken ct) => ValueTask.FromResult<Session?>(null);
         public ValueTask SaveSessionAsync(Session session, CancellationToken ct) => ValueTask.CompletedTask;
@@ -151,6 +197,7 @@ public async ValueTask SweepAsync(
             IReadOnlySet<string> protectedSessionIds,
             CancellationToken ct)
         {
+            LastProtectedSessionIds = new HashSet<string>(protectedSessionIds, StringComparer.Ordinal);
             if (NextResultFactory is null)
             {
                 return new RetentionSweepResult
diff --git a/src/OpenClaw.Tests/ModelProfileSelectionTests.cs b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs
index ce71fdf..9a42c58 100644
--- a/src/OpenClaw.Tests/ModelProfileSelectionTests.cs
+++ b/src/OpenClaw.Tests/ModelProfileSelectionTests.cs
@@ -14,6 +14,7 @@

 namespace OpenClaw.Tests;

+[Collection(DynamicProviderRegistryCollection.Name)]
 public sealed class ModelProfileSelectionTests
 {
     [Fact]
diff --git a/src/OpenClaw.Tests/PluginBridgeIntegrationTests.cs b/src/OpenClaw.Tests/PluginBridgeIntegrationTests.cs
index 87bc987..25bab6f 100644
--- a/src/OpenClaw.Tests/PluginBridgeIntegrationTests.cs
+++ b/src/OpenClaw.Tests/PluginBridgeIntegrationTests.cs
@@ -17,6 +17,7 @@

 namespace OpenClaw.Tests;

+[Collection(DynamicProviderRegistryCollection.Name)]
 public sealed class PluginBridgeIntegrationTests : IDisposable
 {
     private readonly string _tempDir;
diff --git a/src/OpenClaw.Tests/PromptCachingTests.cs b/src/OpenClaw.Tests/PromptCachingTests.cs
new file mode 100644
index 0000000..02cef60
--- /dev/null
+++ b/src/OpenClaw.Tests/PromptCachingTests.cs
@@ -0,0 +1,238 @@
+using System.Text.Json;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.Logging.Abstractions;
+using OpenClaw.Core.Models;
+using OpenClaw.Core.Observability;
+using OpenClaw.Core.Validation;
+using OpenClaw.Gateway.Extensions;
+using OpenClaw.Gateway.Models;
+using OpenClaw.Gateway.PromptCaching;
+using Xunit;
+
+namespace OpenClaw.Tests;
+
+[Collection(DynamicProviderRegistryCollection.Name)]
+public sealed class PromptCachingTests
+{
+    [Fact]
+    public void ConfigValidator_RejectsOpenAiCompatiblePromptCachingWithoutExplicitDialect()
+    {
+        var config = new GatewayConfig
+        {
+            Models = new ModelsConfig
+            {
+                Profiles =
+                [
+                    new ModelProfileConfig
+                    {
+                        Id = "gemma4-prod",
+                        Provider = "openai-compatible",
+                        Model = "gemma-4",
+                        BaseUrl = "https://example.invalid/v1",
+                        ApiKey = "raw:test",
+                        PromptCaching = new PromptCachingConfig
+                        {
+                            Enabled = true,
+                            Dialect = "auto"
+                        }
+                    }
+                ]
+            }
+        };
+
+        var errors = ConfigValidator.Validate(config);
+
+        Assert.Contains(errors, error => error.Contains("Models.Profiles.gemma4-prod.PromptCaching.Dialect", StringComparison.Ordinal));
+    }
+
+    [Fact]
+    public void Registry_MergesProfilePromptCachingOverrideOverGlobalDefaults()
+    {
+        LlmClientFactory.ResetDynamicProviders();
+        LlmClientFactory.RegisterProvider("fake-profile-tests", new TestChatClient());
+
+        var config = new GatewayConfig
+        {
+            Llm = new LlmProviderConfig
+            {
+                Provider = "fake-profile-tests",
+                Model = "legacy-model",
+                PromptCaching = new PromptCachingConfig
+                {
+                    Enabled = true,
+                    Dialect = "openai",
+                    Retention = "long",
+                    TraceEnabled = true
+                }
+            },
+            Models = new ModelsConfig
+            {
+                Profiles =
+                [
+                    new ModelProfileConfig
+                    {
+                        Id = "profile-a",
+                        Provider = "fake-profile-tests",
+                        Model = "model-a",
+                        PromptCaching = new PromptCachingConfig
+                        {
+                            Retention = "short"
+                        }
+                    }
+                ]
+            }
+        };
+
+        var registry = new ConfiguredModelProfileRegistry(config, NullLogger.Instance);
+
+        Assert.True(registry.TryGet("profile-a", out var profile));
+        Assert.True(profile!.PromptCaching.Enabled == true);
+        Assert.Equal("openai", profile.PromptCaching.Dialect);
+        Assert.Equal("short", profile.PromptCaching.Retention);
+        Assert.True(profile.PromptCaching.TraceEnabled == true);
+    }
+
+    [Fact]
+    public void Coordinator_UsesDeterministicFingerprintAcrossVolatileSuffixAndToolOrdering()
+    {
+        var config = new GatewayConfig
+        {
+            Memory = new MemoryConfig
+            {
+                StoragePath = Path.Combine(Path.GetTempPath(), "openclaw-prompt-cache-tests", Guid.NewGuid().ToString("N"))
+            }
+        };
+        Directory.CreateDirectory(config.Memory.StoragePath);
+        var coordinator = new PromptCacheCoordinator(config, new PromptCacheTraceWriter(config));
+        var profile = CreateProfile("anthropic", "anthropic");
+        var session = new Session { Id = "s1", ChannelId = "test", SenderId = "user" };
+
+        var toolA = AIFunctionFactory.CreateDeclaration(
+            "tool_a",
+            "Tool A",
+            JsonDocument.Parse("""{"type":"object","properties":{"value":{"type":"string"}}}""").RootElement.Clone(),
+            returnJsonSchema: null);
+        var toolB = AIFunctionFactory.CreateDeclaration(
+            "tool_b",
+            "Tool B",
+            JsonDocument.Parse("""{"type":"object","properties":{"count":{"type":"integer"}}}""").RootElement.Clone(),
+            returnJsonSchema: null);
+
+        var first = coordinator.Prepare(
+            session,
+            profile,
+            profile.ModelId,
+            [new ChatMessage(ChatRole.System, "Stable prelude\n\n[Route Instructions]\nroute one"), new ChatMessage(ChatRole.User, "hello")],
+            new ChatOptions { Tools = [toolA, toolB] });
+        var second = coordinator.Prepare(
+            session,
+            profile,
+            profile.ModelId,
+            [new ChatMessage(ChatRole.System, "Stable prelude\n\n[Route Instructions]\nroute two"), new ChatMessage(ChatRole.User, "hello")],
+            new ChatOptions { Tools = [toolB, toolA] });
+
+        Assert.Equal(first.Descriptor.StableFingerprint, second.Descriptor.StableFingerprint);
+        Assert.Equal("Stable prelude", first.Descriptor.StableSystemPrompt);
+        Assert.Equal("route one", first.Descriptor.VolatileSuffix);
+    }
+
+    [Fact]
+    public void Coordinator_OnlyMarksKeepWarmEligibleProviders()
+    {
+        var config = new GatewayConfig
+        {
+            Memory = new MemoryConfig
+            {
+                StoragePath = Path.Combine(Path.GetTempPath(), "openclaw-prompt-cache-tests", Guid.NewGuid().ToString("N"))
+            }
+        };
+        Directory.CreateDirectory(config.Memory.StoragePath);
+        var coordinator = new PromptCacheCoordinator(config, new PromptCacheTraceWriter(config));
+        var session = new Session { Id = "s2", ChannelId = "test", SenderId = "user" };
+
+        var openAi = coordinator.Prepare(
+            session,
+            CreateProfile("openai", "openai", keepWarmEnabled: true),
+            "gpt-4.1",
+            [new ChatMessage(ChatRole.System, "Stable prompt"), new ChatMessage(ChatRole.User, "hello")],
+            new ChatOptions());
+
+        var anthropic = coordinator.Prepare(
+            session,
+            CreateProfile("anthropic", "anthropic", keepWarmEnabled: true),
+            "claude-sonnet",
+            [new ChatMessage(ChatRole.System, "Stable prompt"), new ChatMessage(ChatRole.User, "hello")],
+            new ChatOptions());
+
+        Assert.False(openAi.Descriptor.KeepWarmEligible);
+        Assert.True(anthropic.Descriptor.KeepWarmEligible);
+    }
+
+    [Fact]
+    public void PromptCacheUsageExtractor_ReadsNormalizedCacheCounters()
+    {
+        var usage = new UsageDetails
+        {
+            CachedInputTokenCount = 4096,
+            AdditionalCounts = new AdditionalPropertiesDictionary<long>
+            {
+                ["cache_creation_input_tokens"] = 512L
+            }
+        };
+
+        var result = PromptCacheUsageExtractor.FromUsage(usage);
+
+        Assert.Equal(4096, result.CacheReadTokens);
+        Assert.Equal(512, result.CacheWriteTokens);
+    }
+
+    private static ModelProfile CreateProfile(string providerId, string dialect, bool keepWarmEnabled = false)
+        => new()
+        {
+            Id = providerId + "-profile",
+            ProviderId = providerId,
+            ModelId = providerId + "-model",
+            Capabilities = new ModelCapabilities
+            {
+                SupportsPromptCaching = true,
+                SupportsExplicitCacheRetention = true,
+                ReportsCacheReadTokens = true,
+                ReportsCacheWriteTokens = dialect == "anthropic",
+                SupportsSystemMessages = true,
+                SupportsStreaming = true
+            },
+            PromptCaching = new PromptCachingConfig
+            {
+                Enabled = true,
+                Dialect = dialect,
+                Retention = "long",
+                KeepWarmEnabled = keepWarmEnabled
+            }
+        };
+
+    private sealed class TestChatClient : IChatClient
+    {
+        public ChatClientMetadata Metadata => new("fake-profile-tests");
+
+        public object? GetService(Type serviceType, object? serviceKey = null) => null;
+
+        public Task<ChatResponse> GetResponseAsync(
+            IEnumerable<ChatMessage> messages,
+            ChatOptions? options = null,
+            CancellationToken cancellationToken = default)
+            => Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, "ok")));
+
+        public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
+            IEnumerable<ChatMessage> messages,
+            ChatOptions? options = null,
+            [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            yield return new ChatResponseUpdate(ChatRole.Assistant, "ok");
+            await Task.CompletedTask;
+        }
+
+        public void Dispose()
+        {
+        }
+    }
+}
diff --git a/src/OpenClaw.Tests/SqliteSessionSearchTests.cs b/src/OpenClaw.Tests/SqliteSessionSearchTests.cs
index 5f61732..9a564a2 100644
--- a/src/OpenClaw.Tests/SqliteSessionSearchTests.cs
+++ b/src/OpenClaw.Tests/SqliteSessionSearchTests.cs
@@ -155,6 +155,39 @@ await store.SaveSessionAsync(new Session
         }
     }
+
+    [Fact]
+    public async Task GetSessionAsync_CorruptRow_ThrowsCorruptionException()
+    {
+        var root = CreateTempDirectory();
+        try
+        {
+            var dbPath = Path.Combine(root, "memory.db");
+            using var store = new SqliteMemoryStore(dbPath, enableFts: false);
+
+            await store.SaveSessionAsync(new Session
+            {
+                Id = "session-corrupt",
+                ChannelId = "websocket",
+                SenderId = "alice"
+            }, CancellationToken.None);
+
+            await using var conn = new Microsoft.Data.Sqlite.SqliteConnection(new Microsoft.Data.Sqlite.SqliteConnectionStringBuilder { DataSource = dbPath }.ToString());
+            await conn.OpenAsync();
+            await using var cmd = conn.CreateCommand();
+            cmd.CommandText = "UPDATE sessions SET json = '{not valid json' WHERE id = $id;";
+            cmd.Parameters.AddWithValue("$id", "session-corrupt");
+            await cmd.ExecuteNonQueryAsync();
+
+            var ex = await Assert.ThrowsAsync(async () =>
+                await store.GetSessionAsync("session-corrupt", CancellationToken.None));
+            Assert.Equal("session-corrupt", ex.SessionId);
+        }
+        finally
+        {
+            Directory.Delete(root, recursive: true);
+        }
+    }
+
     private static string CreateTempDirectory()
     {
         var path = Path.Combine(Path.GetTempPath(), "openclaw-tests", Guid.NewGuid().ToString("n"));