From d1d9540834e8ebde6e65435dfa5e470b1b16e145 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Thu, 11 Dec 2025 14:53:37 +0100 Subject: [PATCH 1/5] feat(ai-chat): add automatic chat session summarization for long conversations Automatically summarizes chat sessions when token usage approaches the context limit (90% of 200k tokens), enabling continued conversations without losing context from earlier messages. Core functionality: - Add `ChatSessionSummarizationService` to orchestrate summarization - Add `insertSummary()` method to `MutableChatModel` for inserting summary nodes - Add `isStale` flag to mark pre-summary messages (excluded from future prompts) - Add `kind` field to `ChatRequest` interface ('user' | 'summary') Budget-aware tool loop: - Add `singleRoundTrip` flag to `UserRequest` for controlled tool execution - Extend `ChatLanguageModelServiceImpl` with budget checking before/during requests - Trigger mid-turn summarization when threshold exceeded during tool loops - Support both threshold-triggered and explicit summarization Token usage tracking: - Add `TokenUsageService` for recording token usage across providers - Add `TokenUsageServiceClient` for frontend notification of usage updates - Display token count indicator in chat UI with session switching support UI components: - Add collapsible summary node rendering with bookmark icon - Add `SummaryPartRenderer` for displaying summary content - Add token usage indicator showing current session token count fixes #16703 fixes #16724 Current Limitations: - only supported by anthropic - hard coded budget of 200k tokens - hard coded trigger when reaching 90% of tokens --- .../src/node/anthropic-language-model.ts | 31 +- .../src/browser/ai-chat-ui-frontend-module.ts | 2 + .../summary-part-renderer.tsx | 67 +++ .../chat-token-usage-indicator.spec.ts | 306 ++++++++++++ .../browser/chat-token-usage-indicator.tsx | 115 +++++ .../chat-tree-view/chat-view-tree-widget.tsx | 46 +- .../src/browser/chat-view-widget.tsx | 37 ++ .../ai-chat-ui/src/browser/style/index.css | 197 +++++++- packages/ai-chat/README.md | 24 + .../src/browser/ai-chat-frontend-module.ts | 22 +- .../chat-language-model-service.spec.ts | 441 ++++++++++++++++++ .../browser/chat-language-model-service.ts | 317 +++++++++++++ .../src/browser/chat-session-store-impl.ts | 7 +- .../chat-session-summarization-service.ts | 215 +++++++++ ...chat-session-token-restore-contribution.ts | 42 ++ .../chat-session-token-tracker.spec.ts | 207 ++++++++ .../src/browser/chat-session-token-tracker.ts | 105 +++++ packages/ai-chat/src/browser/index.ts | 17 + .../ai-chat/src/common/ai-chat-preferences.ts | 10 + packages/ai-chat/src/common/chat-agents.ts | 5 + .../ai-chat/src/common/chat-auto-save.spec.ts | 1 + .../common/chat-content-deserializer.spec.ts | 31 ++ .../src/common/chat-content-deserializer.ts | 9 +- .../common/chat-model-insert-summary.spec.ts | 319 +++++++++++++ .../common/chat-model-serialization.spec.ts | 65 +++ .../src/common/chat-model-serialization.ts | 6 + packages/ai-chat/src/common/chat-model.ts | 201 +++++++- packages/ai-chat/src/common/chat-service.ts | 18 +- .../ai-chat/src/common/chat-session-store.ts | 1 + .../src/common/chat-session-token-tracker.ts | 73 +++ packages/ai-chat/src/common/index.ts | 1 + packages/ai-core/src/common/language-model.ts | 19 +- .../ai-core/src/common/token-usage-service.ts | 4 + .../src/node/token-usage-service-impl.ts | 3 +- .../src/node/google-language-model.ts | 6 +- .../src/node/ollama-language-model.ts | 41 +- .../src/node/openai-language-model.ts | 8 +- .../src/node/openai-response-api-utils.ts | 9 +- .../node/openai-streaming-iterator.spec.ts | 2 +- .../src/node/openai-streaming-iterator.ts | 4 +- .../src/node/vercel-ai-language-model.ts | 6 +- 41 files changed, 2979 insertions(+), 61 deletions(-) create mode 100644 packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx create mode 100644 packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts create mode 100644 packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx create mode 100644 packages/ai-chat/src/browser/chat-language-model-service.spec.ts create mode 100644 packages/ai-chat/src/browser/chat-language-model-service.ts create mode 100644 packages/ai-chat/src/browser/chat-session-summarization-service.ts create mode 100644 packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts create mode 100644 packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts create mode 100644 packages/ai-chat/src/browser/chat-session-token-tracker.ts create mode 100644 packages/ai-chat/src/browser/index.ts create mode 100644 packages/ai-chat/src/common/chat-model-insert-summary.spec.ts create mode 100644 packages/ai-chat/src/common/chat-session-token-tracker.ts diff --git a/packages/ai-anthropic/src/node/anthropic-language-model.ts b/packages/ai-anthropic/src/node/anthropic-language-model.ts index 006acd1898b82..4687bf98edd14 100644 --- a/packages/ai-anthropic/src/node/anthropic-language-model.ts +++ b/packages/ai-anthropic/src/node/anthropic-language-model.ts @@ -315,15 +315,21 @@ export class AnthropicModel implements LanguageModel { currentMessage = event.message; } else if (event.type === 'message_stop') { if (currentMessage) { - yield { input_tokens: currentMessage.usage.input_tokens, output_tokens: currentMessage.usage.output_tokens }; + yield { + input_tokens: currentMessage.usage.input_tokens, + output_tokens: currentMessage.usage.output_tokens, + cache_creation_input_tokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, + cache_read_input_tokens: currentMessage.usage.cache_read_input_tokens ?? undefined + }; // Record token usage if token usage service is available if (that.tokenUsageService && currentMessage.usage) { const tokenUsageParams: TokenUsageParams = { inputTokens: currentMessage.usage.input_tokens, outputTokens: currentMessage.usage.output_tokens, - cachedInputTokens: currentMessage.usage.cache_creation_input_tokens || undefined, - readCachedInputTokens: currentMessage.usage.cache_read_input_tokens || undefined, - requestId: request.requestId + cachedInputTokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, + readCachedInputTokens: currentMessage.usage.cache_read_input_tokens ?? undefined, + requestId: request.requestId, + sessionId: request.sessionId }; await that.tokenUsageService.recordTokenUsage(that.id, tokenUsageParams); } @@ -332,6 +338,18 @@ export class AnthropicModel implements LanguageModel { } } if (toolCalls.length > 0) { + // If singleRoundTrip is true, yield tool calls without executing them + // The caller is responsible for tool execution and continuation + if (request.singleRoundTrip) { + const pendingCalls = toolCalls.map(tc => ({ + finished: true, + id: tc.id, + function: { name: tc.name, arguments: tc.args.length === 0 ? '{}' : tc.args } + })); + yield { tool_calls: pendingCalls }; + return; + } + const toolResult = await Promise.all(toolCalls.map(async tc => { const tool = request.tools?.find(t => t.name === tc.name); const argsObject = tc.args.length === 0 ? '{}' : tc.args; @@ -415,7 +433,10 @@ export class AnthropicModel implements LanguageModel { const tokenUsageParams: TokenUsageParams = { inputTokens: response.usage.input_tokens, outputTokens: response.usage.output_tokens, - requestId: request.requestId + cachedInputTokens: response.usage.cache_creation_input_tokens ?? undefined, + readCachedInputTokens: response.usage.cache_read_input_tokens ?? undefined, + requestId: request.requestId, + sessionId: request.sessionId }; await this.tokenUsageService.recordTokenUsage(this.id, tokenUsageParams); } diff --git a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts index e0250fdf579fb..3214f098b5422 100644 --- a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts +++ b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts @@ -40,6 +40,7 @@ import { TextPartRenderer, } from './chat-response-renderer'; import { UnknownPartRenderer } from './chat-response-renderer/unknown-part-renderer'; +import { SummaryPartRenderer } from './chat-response-renderer/summary-part-renderer'; import { GitHubSelectionResolver, TextFragmentSelectionResolver, @@ -139,6 +140,7 @@ export default new ContainerModule((bind, _unbind, _isBound, rebind) => { bind(ChatResponsePartRenderer).to(TextPartRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(DelegationResponseRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(UnknownPartRenderer).inSingletonScope(); + bind(ChatResponsePartRenderer).to(SummaryPartRenderer).inSingletonScope(); [CommandContribution, MenuContribution].forEach(serviceIdentifier => bind(serviceIdentifier).to(ChatViewMenuContribution).inSingletonScope() ); diff --git a/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx b/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx new file mode 100644 index 0000000000000..5021352bd4db4 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx @@ -0,0 +1,67 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { ChatResponsePartRenderer } from '../chat-response-part-renderer'; +import { inject, injectable } from '@theia/core/shared/inversify'; +import { ChatResponseContent, SummaryChatResponseContent } from '@theia/ai-chat/lib/common'; +import { ReactNode } from '@theia/core/shared/react'; +import { nls } from '@theia/core/lib/common/nls'; +import * as React from '@theia/core/shared/react'; +import { OpenerService } from '@theia/core/lib/browser'; +import { useMarkdownRendering } from './markdown-part-renderer'; + +/** + * Renderer for SummaryChatResponseContent. + * Displays the summary in a collapsible section that is collapsed by default. + */ +@injectable() +export class SummaryPartRenderer implements ChatResponsePartRenderer { + + @inject(OpenerService) + protected readonly openerService: OpenerService; + + canHandle(response: ChatResponseContent): number { + if (SummaryChatResponseContent.is(response)) { + return 10; + } + return -1; + } + + render(response: SummaryChatResponseContent): ReactNode { + return ; + } +} + +interface SummaryContentProps { + content: string; + openerService: OpenerService; +} + +const SummaryContent: React.FC = ({ content, openerService }) => { + const contentRef = useMarkdownRendering(content, openerService); + + return ( +
+
+ + + {nls.localize('theia/ai/chat-ui/summary-part-renderer/conversationSummary', 'Conversation Summary')} + +
+
+
+ ); +}; diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts new file mode 100644 index 0000000000000..4a47251b145e6 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts @@ -0,0 +1,306 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { enableJSDOM } from '@theia/core/lib/browser/test/jsdom'; +let disableJSDOM = enableJSDOM(); +import { FrontendApplicationConfigProvider } from '@theia/core/lib/browser/frontend-application-config-provider'; +FrontendApplicationConfigProvider.set({}); + +import { expect } from 'chai'; +import * as React from '@theia/core/shared/react'; +import * as ReactDOMClient from '@theia/core/shared/react-dom/client'; +import { flushSync } from '@theia/core/shared/react-dom'; +import { Emitter } from '@theia/core'; +import { + ChatSessionTokenTracker, + SessionTokenThresholdEvent, + SessionTokenUpdateEvent, + CHAT_TOKEN_THRESHOLD +} from '@theia/ai-chat/lib/browser'; +import { ChatTokenUsageIndicator, ChatTokenUsageIndicatorProps } from './chat-token-usage-indicator'; + +disableJSDOM(); + +describe('ChatTokenUsageIndicator', () => { + let container: HTMLDivElement; + let root: ReactDOMClient.Root; + + const createMockTokenTracker = (tokens: number | undefined): ChatSessionTokenTracker => { + const thresholdEmitter = new Emitter(); + const updateEmitter = new Emitter(); + return { + onThresholdExceeded: thresholdEmitter.event, + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => tokens, + resetSessionTokens: () => { }, + resetThresholdTrigger: () => { } + }; + }; + + before(() => { + disableJSDOM = enableJSDOM(); + }); + + after(() => { + disableJSDOM(); + }); + + beforeEach(() => { + container = document.createElement('div'); + document.body.appendChild(container); + root = ReactDOMClient.createRoot(container); + }); + + afterEach(() => { + flushSync(() => { + root.unmount(); + }); + container.remove(); + }); + + const renderComponent = (props: ChatTokenUsageIndicatorProps): void => { + flushSync(() => { + root.render(React.createElement(ChatTokenUsageIndicator, props)); + }); + }; + + describe('token formatting', () => { + it('should display "-" when no tokens are tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('-'); + }); + + it('should format small token counts as plain numbers', () => { + const mockTracker = createMockTokenTracker(500); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('500'); + }); + + it('should format large token counts with "k" suffix', () => { + const mockTracker = createMockTokenTracker(125000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('125k'); + }); + + it('should format token counts with decimal "k" suffix when needed', () => { + const mockTracker = createMockTokenTracker(1500); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('1.5k'); + }); + }); + + describe('color coding', () => { + it('should have green class when usage is below 70%', () => { + // 70% of CHAT_TOKEN_THRESHOLD = 126000, so 100000 is below + expect(Math.round(CHAT_TOKEN_THRESHOLD * 0.7)).to.equal(126000); + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-green')).to.be.true; + }); + + it('should have yellow class when usage is between 70% and 90%', () => { + // 70% of CHAT_TOKEN_THRESHOLD = 126000 + // 90% of CHAT_TOKEN_THRESHOLD = 162000 + const mockTracker = createMockTokenTracker(150000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-yellow')).to.be.true; + }); + + it('should have red class when usage is at or above 90%', () => { + // 90% of CHAT_TOKEN_THRESHOLD = 162000 + const mockTracker = createMockTokenTracker(170000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-red')).to.be.true; + }); + + it('should have none class when no tokens are tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-none')).to.be.true; + }); + }); + + describe('tooltip', () => { + it('should include budget-aware status in tooltip when enabled', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Budget-aware: Enabled'); + }); + + it('should include budget-aware status in tooltip when disabled', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: false + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Budget-aware: Disabled'); + }); + + it('should include threshold and budget values in tooltip', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Threshold:'); + expect(title).to.include('Budget:'); + }); + + it('should show "None" in tooltip when no tokens tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Tokens: None'); + }); + }); + + describe('subscription to token updates', () => { + it('should update when token tracker fires update event', () => { + const updateEmitter = new Emitter(); + const thresholdEmitter = new Emitter(); + let currentTokens = 50000; + + const mockTracker: ChatSessionTokenTracker = { + onThresholdExceeded: thresholdEmitter.event, + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => currentTokens, + resetSessionTokens: () => { }, + resetThresholdTrigger: () => { } + }; + + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + // Initial state + let textContent = container.textContent; + expect(textContent).to.contain('50k'); + + // Fire update event within flushSync to ensure synchronous React update + currentTokens = 100000; + flushSync(() => { + updateEmitter.fire({ sessionId: 'test-session', inputTokens: 100000 }); + }); + + textContent = container.textContent; + expect(textContent).to.contain('100k'); + }); + + it('should not update when event is for different session', () => { + const updateEmitter = new Emitter(); + const thresholdEmitter = new Emitter(); + + const mockTracker: ChatSessionTokenTracker = { + onThresholdExceeded: thresholdEmitter.event, + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => 50000, + resetSessionTokens: () => { }, + resetThresholdTrigger: () => { } + }; + + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + // Initial state + let textContent = container.textContent; + expect(textContent).to.contain('50k'); + + // Fire update event for different session within flushSync + flushSync(() => { + updateEmitter.fire({ sessionId: 'other-session', inputTokens: 100000 }); + }); + + textContent = container.textContent; + // Should still show 50k since we didn't update our session + expect(textContent).to.contain('50k'); + }); + }); +}); diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx new file mode 100644 index 0000000000000..fa565a2baa081 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx @@ -0,0 +1,115 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import * as React from '@theia/core/shared/react'; +import { + ChatSessionTokenTracker, + CHAT_TOKEN_BUDGET, + CHAT_TOKEN_THRESHOLD +} from '@theia/ai-chat/lib/browser'; + +export interface ChatTokenUsageIndicatorProps { + sessionId: string; + tokenTracker: ChatSessionTokenTracker; + budgetAwareEnabled: boolean; +} + +/** + * Formats a token count to a human-readable string. + * E.g., 125000 -> "125k", 1500 -> "1.5k", 500 -> "500" + */ +const formatTokenCount = (tokens: number | undefined): string => { + if (tokens === undefined) { + return '-'; + } + if (tokens >= 1000) { + const k = tokens / 1000; + // Show one decimal place if needed, otherwise whole number + return k % 1 === 0 ? `${k}k` : `${k.toFixed(1)}k`; + } + return tokens.toString(); +}; + +/** + * Determines the color class based on usage percentage. + * - Green: <70% + * - Yellow: 70-<90% + * - Red: ≥90% + */ +const getUsageColorClass = (tokens: number | undefined, threshold: number): string => { + if (tokens === undefined) { + return 'token-usage-none'; + } + const percentage = (tokens / threshold) * 100; + if (percentage >= 90) { + return 'token-usage-red'; + } + if (percentage >= 70) { + return 'token-usage-yellow'; + } + return 'token-usage-green'; +}; + +/** + * A React component that displays the current token usage for a chat session. + * Shows current input tokens vs threshold with color coding based on usage percentage. + */ +export const ChatTokenUsageIndicator: React.FC = ({ + sessionId, + tokenTracker, + budgetAwareEnabled +}) => { + const [inputTokens, setInputTokens] = React.useState( + () => tokenTracker.getSessionInputTokens(sessionId) + ); + + React.useEffect(() => { + // Get initial value + setInputTokens(tokenTracker.getSessionInputTokens(sessionId)); + + // Subscribe to token updates + const disposable = tokenTracker.onSessionTokensUpdated(event => { + if (event.sessionId === sessionId) { + setInputTokens(event.inputTokens); + } + }); + + return () => disposable.dispose(); + }, [sessionId, tokenTracker]); + + const thresholdFormatted = formatTokenCount(CHAT_TOKEN_THRESHOLD); + const budgetFormatted = formatTokenCount(CHAT_TOKEN_BUDGET); + const currentFormatted = formatTokenCount(inputTokens); + const colorClass = getUsageColorClass(inputTokens, CHAT_TOKEN_THRESHOLD); + + const tooltipText = [ + `Tokens: ${inputTokens !== undefined ? inputTokens.toLocaleString() : 'None'}`, + `Threshold: ${CHAT_TOKEN_THRESHOLD.toLocaleString()} (${thresholdFormatted})`, + `Budget: ${CHAT_TOKEN_BUDGET.toLocaleString()} (${budgetFormatted})`, + `Budget-aware: ${budgetAwareEnabled ? 'Enabled' : 'Disabled'}` + ].join('\n'); + + return ( +
+ + {currentFormatted} / {thresholdFormatted} tokens + +
+ ); +}; diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index c2edadf90a7c9..87173c728a928 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -487,7 +487,10 @@ export class ChatViewTreeWidget extends TreeWidget { chatModel.getBranches().forEach(branch => { const request = branch.get(); nodes.push(this.mapRequestToNode(branch)); - nodes.push(this.mapResponseToNode(request.response)); + // Skip ResponseNode for summary nodes - they render content in the RequestNode + if (request.request.kind !== 'summary') { + nodes.push(this.mapResponseToNode(request.response)); + } }); this.model.root.children = nodes; this.model.refresh(); @@ -504,6 +507,17 @@ export class ChatViewTreeWidget extends TreeWidget { if (!(isRequestNode(node) || isResponseNode(node))) { return super.renderNode(node, props); } + + // Summary nodes render without agent header + const isSummaryNode = isRequestNode(node) && node.request.request.kind === 'summary'; + if (isSummaryNode) { + return +
this.handleContextMenu(node, e)}> + {this.renderDetail(node)} +
+
; + } + return
this.handleContextMenu(node, e)}> {this.renderAgent(node)} @@ -754,6 +768,11 @@ const WidgetContainer: React.FC = ({ widget }) => { return
; }; +const SummaryContentRenderer: React.FC<{ content: string; openerService: OpenerService }> = ({ content, openerService }) => { + const ref = useMarkdownRendering(content, openerService); + return
; +}; + const ChatRequestRender = ( { node, hoverService, chatAgentService, variableService, openerService, @@ -767,6 +786,29 @@ const ChatRequestRender = ( provideChatInputWidget: () => ReactWidget | undefined, }) => { const parts = node.request.message.parts; + const isStale = node.request.isStale === true; + const isSummaryNode = node.request.request.kind === 'summary'; + + // Summary nodes render header and content in a single unified node + if (isSummaryNode) { + const summaryContent = node.request.response.response.asDisplayString(); + return ( +
+
+ + + {nls.localize('theia/ai/chat-ui/chat-view-tree-widget/conversationSummary', 'Conversation Summary')} + + {summaryContent && ( +
+ +
+ )} +
+
+ ); + } + if (EditableChatRequestModel.isEditing(node.request)) { const widget = provideChatInputWidget(); if (widget) { @@ -805,7 +847,7 @@ const ChatRequestRender = ( }; return ( -
+

{parts.map((part, index) => { if (part instanceof ParsedChatRequestAgentPart || part instanceof ParsedChatRequestVariablePart) { diff --git a/packages/ai-chat-ui/src/browser/chat-view-widget.tsx b/packages/ai-chat-ui/src/browser/chat-view-widget.tsx index b16555cf52f15..b206fff94b1f5 100644 --- a/packages/ai-chat-ui/src/browser/chat-view-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-view-widget.tsx @@ -15,6 +15,8 @@ // ***************************************************************************** import { CommandService, deepClone, Emitter, Event, MessageService, PreferenceService, URI } from '@theia/core'; import { ChatRequest, ChatRequestModel, ChatService, ChatSession, isActiveSessionChangedEvent, MutableChatModel } from '@theia/ai-chat'; +import { ChatSessionTokenTracker } from '@theia/ai-chat/lib/browser'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '@theia/ai-chat/lib/common/ai-chat-preferences'; import { BaseWidget, codicon, ExtractableWidget, Message, PanelLayout, StatefulWidget } from '@theia/core/lib/browser'; import { nls } from '@theia/core/lib/common/nls'; import { inject, injectable, optional, postConstruct } from '@theia/core/shared/inversify'; @@ -25,6 +27,9 @@ import { AIVariableResolutionRequest } from '@theia/ai-core'; import { ProgressBarFactory } from '@theia/core/lib/browser/progress-bar-factory'; import { FrontendVariableService } from '@theia/ai-core/lib/browser'; import { FrontendLanguageModelRegistry } from '@theia/ai-core/lib/common'; +import { ChatTokenUsageIndicator } from './chat-token-usage-indicator'; +import * as React from '@theia/core/shared/react'; +import { Root, createRoot } from '@theia/core/shared/react-dom/client'; export namespace ChatViewWidget { export interface State { @@ -63,11 +68,17 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta @inject(FrontendLanguageModelRegistry) protected readonly languageModelRegistry: FrontendLanguageModelRegistry; + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + @inject(ChatWelcomeMessageProvider) @optional() protected readonly welcomeProvider?: ChatWelcomeMessageProvider; protected chatSession: ChatSession; + protected tokenIndicatorContainer: HTMLDivElement; + protected tokenIndicatorRoot: Root; + protected _state: ChatViewWidget.State = { locked: false, temporaryLocked: false }; protected readonly onStateChangedEmitter = new Emitter(); @@ -107,8 +118,17 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta layout.addWidget(this.treeWidget); this.inputWidget.node.classList.add('chat-input-widget'); layout.addWidget(this.inputWidget); + + // Add token indicator container after inputWidget is added to the layout + // so insertAdjacentElement can properly place it in the DOM + this.tokenIndicatorContainer = document.createElement('div'); + this.tokenIndicatorContainer.classList.add('chat-token-usage-container'); + this.inputWidget.node.insertAdjacentElement('beforebegin', this.tokenIndicatorContainer); this.chatSession = this.chatService.createSession(); + this.tokenIndicatorRoot = createRoot(this.tokenIndicatorContainer); + this.renderTokenIndicator(); + this.inputWidget.onQuery = this.onQuery.bind(this); this.inputWidget.onUnpin = this.onUnpin.bind(this); this.inputWidget.onCancel = this.onCancel.bind(this); @@ -177,6 +197,7 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta this.treeWidget.trackChatModel(this.chatSession.model); this.inputWidget.chatModel = this.chatSession.model; this.inputWidget.pinnedAgent = this.chatSession.pinnedAgent; + this.renderTokenIndicator(); } else { console.warn(`Session with ${event.sessionId} not found.`); } @@ -299,4 +320,20 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta getSettings(): { [key: string]: unknown } | undefined { return this.chatSession.model.settings; } + + protected renderTokenIndicator(): void { + const budgetAwareEnabled = this.preferenceService.get(BUDGET_AWARE_TOOL_LOOP_PREF, false); + this.tokenIndicatorRoot.render( + React.createElement(ChatTokenUsageIndicator, { + sessionId: this.chatSession.id, + tokenTracker: this.tokenTracker, + budgetAwareEnabled + }) + ); + } + + override dispose(): void { + this.tokenIndicatorRoot.unmount(); + super.dispose(); + } } diff --git a/packages/ai-chat-ui/src/browser/style/index.css b/packages/ai-chat-ui/src/browser/style/index.css index 3c3946c04c0df..5bcec8d7ed502 100644 --- a/packages/ai-chat-ui/src/browser/style/index.css +++ b/packages/ai-chat-ui/src/browser/style/index.css @@ -7,8 +7,8 @@ flex: 1; } -.chat-input-widget > .ps__rail-x, -.chat-input-widget > .ps__rail-y { +.chat-input-widget>.ps__rail-x, +.chat-input-widget>.ps__rail-y { display: none !important; } @@ -23,7 +23,7 @@ overflow-wrap: break-word; } -div:last-child > .theia-ChatNode { +div:last-child>.theia-ChatNode { border: none; } @@ -59,6 +59,7 @@ div:last-child > .theia-ChatNode { } @keyframes dots { + 0%, 20% { content: ""; @@ -137,7 +138,7 @@ div:last-child > .theia-ChatNode { padding-inline-start: 1rem; } -.theia-ChatNode li > p { +.theia-ChatNode li>p { margin-top: 0; margin-bottom: 0; } @@ -151,7 +152,7 @@ div:last-child > .theia-ChatNode { font-size: var(--theia-code-font-size); } -.theia-RequestNode > p div { +.theia-RequestNode>p div { display: inline; } @@ -450,8 +451,7 @@ div:last-child > .theia-ChatNode { text-align: center; } -.theia-ChatInput-ChangeSet-List - .theia-ChatInput-ChangeSet-Icon.codicon::before { +.theia-ChatInput-ChangeSet-List .theia-ChatInput-ChangeSet-Icon.codicon::before { font-size: var(--theia-ui-font-size1); } @@ -468,8 +468,7 @@ div:last-child > .theia-ChatNode { color: var(--theia-disabledForeground); } -.theia-ChatInput-ChangeSet-List - .theia-ChatInput-ChangeSet-AdditionalInfo-SuffixIcon { +.theia-ChatInput-ChangeSet-List .theia-ChatInput-ChangeSet-AdditionalInfo-SuffixIcon { font-size: var(--theia-ui-font-size0) px; margin-left: 4px; } @@ -749,8 +748,7 @@ div:last-child > .theia-ChatNode { display: flex; flex-direction: column; gap: 8px; - border: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); padding: 8px 12px 12px; border-radius: 5px; margin: 0 0 8px 0; @@ -1106,8 +1104,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { /* Delegation response styles */ .theia-delegation-container { - border: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); border-radius: var(--theia-ui-padding); margin: var(--theia-ui-padding) 0; background-color: var(--theia-sideBar-background); @@ -1122,8 +1119,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { padding: var(--theia-ui-padding); background-color: var(--theia-editorGroupHeader-tabsBackground); border-radius: var(--theia-ui-padding) var(--theia-ui-padding) 0 0; - border-bottom: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border-bottom: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); list-style: none; position: relative; } @@ -1226,8 +1222,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { .delegation-prompt-section { margin-bottom: var(--theia-ui-padding); padding-bottom: var(--theia-ui-padding); - border-bottom: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border-bottom: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); } .delegation-prompt { @@ -1244,7 +1239,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { margin-top: var(--theia-ui-padding); } -.delegation-response-section > strong { +.delegation-response-section>strong { display: block; margin-bottom: var(--theia-ui-padding); color: var(--theia-foreground); @@ -1283,6 +1278,136 @@ details[open].collapsible-arguments .collapsible-arguments-summary { color: var(--theia-button-foreground, #fff); } +/* Stale request indicator styles */ +.theia-RequestNode-stale { + opacity: 0.7; +} + + +/* Summary node styles */ +.theia-ChatNode-Summary { + background-color: var(--theia-editorGroupHeader-tabsBackground); + padding: 8px 16px; +} + +.theia-RequestNode-summary { + padding: 0; + width: 100%; +} + +.theia-RequestNode-summary details { + width: 100%; + border: 1px solid var(--theia-sideBarSectionHeader-border); + border-radius: 4px; + background-color: var(--theia-sideBar-background); +} + +.theia-RequestNode-SummaryHeader { + cursor: pointer; + padding: 8px 12px; + display: flex; + align-items: center; + gap: 8px; + font-weight: 600; + color: var(--theia-foreground); + user-select: none; + list-style: none; + position: relative; +} + +.theia-RequestNode-SummaryHeader::-webkit-details-marker { + display: none; +} + +.theia-RequestNode-SummaryHeader::marker { + content: ""; +} + +.theia-RequestNode-SummaryHeader:hover { + background-color: var(--theia-toolbar-hoverBackground); +} + +.theia-RequestNode-SummaryHeader .codicon { + color: var(--theia-button-background); + font-size: 14px; +} + +.theia-RequestNode-SummaryHeader::after { + content: "\25BC"; + position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + transition: transform 0.2s ease; + color: var(--theia-descriptionForeground); + font-size: var(--theia-ui-font-size1); +} + +.theia-RequestNode-summary details:not([open]) .theia-RequestNode-SummaryHeader::after { + transform: translateY(-50%) rotate(-90deg); +} + +.theia-RequestNode-SummaryContent { + padding: 12px; + border-top: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-editor-background); + border-radius: 0 0 4px 4px; +} + +.theia-RequestNode-SummaryIndicator { + display: flex; + align-items: center; + gap: 6px; + padding: 4px 8px; + background-color: var(--theia-editorGroupHeader-tabsBackground); + color: var(--theia-descriptionForeground); + border-radius: 4px; + font-size: 12px; + font-weight: 500; +} + +.theia-RequestNode-SummaryIndicator .codicon { + font-size: 14px; + color: var(--theia-button-background); +} + +/* Chat summary styles */ +.theia-chat-summary { + margin: 8px 0; + border: 1px solid var(--theia-sideBarSectionHeader-border); + border-radius: 4px; + background-color: var(--theia-editorGroupHeader-tabsBackground); +} + +.theia-chat-summary details { + width: 100%; +} + +.theia-chat-summary summary { + cursor: pointer; + padding: 8px 12px; + display: flex; + align-items: center; + gap: 8px; + font-weight: 600; + color: var(--theia-foreground); + user-select: none; +} + +.theia-chat-summary summary:hover { + background-color: var(--theia-toolbar-hoverBackground); +} + +.theia-chat-summary summary .codicon { + color: var(--theia-descriptionForeground); +} + +.theia-chat-summary-content { + padding: 12px; + border-top: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-editor-background); +} + /* Unknown content styles */ .theia-chat-unknown-content { margin: 4px 0; @@ -1310,3 +1435,39 @@ details[open].collapsible-arguments .collapsible-arguments-summary { font-family: var(--theia-code-font-family); white-space: pre-wrap; } + +/* Token usage indicator styles */ +.chat-token-usage-container { + padding: 4px 8px; + border-top: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-sideBar-background); +} + +.theia-ChatTokenUsageIndicator { + font-size: 11px; + display: flex; + align-items: center; + justify-content: center; + gap: 4px; + font-family: var(--theia-ui-font-family); +} + +.theia-ChatTokenUsageIndicator .token-usage-text { + white-space: nowrap; +} + +.theia-ChatTokenUsageIndicator.token-usage-none { + color: var(--theia-disabledForeground); +} + +.theia-ChatTokenUsageIndicator.token-usage-green { + color: var(--theia-successForeground, var(--theia-charts-green)); +} + +.theia-ChatTokenUsageIndicator.token-usage-yellow { + color: var(--theia-editorWarning-foreground, var(--theia-charts-yellow)); +} + +.theia-ChatTokenUsageIndicator.token-usage-red { + color: var(--theia-errorForeground, var(--theia-charts-red)); +} diff --git a/packages/ai-chat/README.md b/packages/ai-chat/README.md index 8aafbd53d7a22..ceca078422f02 100644 --- a/packages/ai-chat/README.md +++ b/packages/ai-chat/README.md @@ -15,6 +15,30 @@ The `@theia/ai-chat` extension provides the concept of a language model chat to Theia. It serves as the basis for `@theia/ai-chat-ui` to provide the Chat UI. +## Features + +### Budget-Aware Tool Loop (Experimental) + +When enabled via the `ai-features.chat.experimentalBudgetAwareToolLoop` preference, the chat system can automatically trigger conversation summarization when the token budget is exceeded during tool call loops. This prevents "context too long" API errors during complex multi-tool tasks. + +**How it works:** +1. When a chat request includes tools, the system manages the tool loop externally instead of letting the language model handle it internally +2. Between tool call iterations, the system checks if the token budget is exceeded +3. If exceeded, it triggers summarization to compress the conversation history +4. The task continues with the summarized context plus the pending tool calls + +**Requirements:** +- Currently only supports Anthropic models (other models fall back to standard behavior) +- Requires the language model to support the `singleRoundTrip` request property + +**Preference:** + +```json +{ + "ai-features.chat.experimentalBudgetAwareToolLoop": true +} +``` + ## Additional Information - [API documentation for `@theia/ai-chat`](https://eclipse-theia.github.io/theia/docs/next/modules/_theia_ai-chat.html) diff --git a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts index 7661bab3c6d17..55939a898d0bb 100644 --- a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts +++ b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts @@ -14,7 +14,7 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { Agent, AgentService, AIVariableContribution, bindToolProvider } from '@theia/ai-core/lib/common'; +import { Agent, AgentService, AIVariableContribution, bindToolProvider, LanguageModelService } from '@theia/ai-core/lib/common'; import { bindContributionProvider, CommandContribution, PreferenceContribution } from '@theia/core'; import { FrontendApplicationContribution, LabelProviderContribution } from '@theia/core/lib/browser'; import { ContainerModule } from '@theia/core/shared/inversify'; @@ -73,8 +73,12 @@ import { ChangeSetElementDeserializerRegistryImpl } from '../common/change-set-element-deserializer'; import { ChangeSetFileElementDeserializerContribution } from './change-set-file-element-deserializer'; +import { ChatSessionTokenTracker, ChatSessionTokenTrackerImpl } from './chat-session-token-tracker'; +import { ChatSessionTokenRestoreContribution } from './chat-session-token-restore-contribution'; +import { ChatSessionSummarizationService, ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; +import { ChatLanguageModelServiceImpl } from './chat-language-model-service'; -export default new ContainerModule(bind => { +export default new ContainerModule((bind, unbind, isBound, rebind) => { bindContributionProvider(bind, ChatAgent); bind(ChatContentDeserializerRegistryImpl).toSelf().inSingletonScope(); @@ -186,4 +190,18 @@ export default new ContainerModule(bind => { bind(CommandContribution).toService(AIChatFrontendContribution); bindToolProvider(AgentDelegationTool, bind); + + bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); + bind(ChatSessionTokenTracker).toService(ChatSessionTokenTrackerImpl); + + bind(ChatSessionTokenRestoreContribution).toSelf().inSingletonScope(); + bind(FrontendApplicationContribution).toService(ChatSessionTokenRestoreContribution); + + bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); + bind(ChatSessionSummarizationService).toService(ChatSessionSummarizationServiceImpl); + bind(FrontendApplicationContribution).toService(ChatSessionSummarizationServiceImpl); + + // Rebind LanguageModelService to use the chat-aware implementation with budget-aware tool loop + bind(ChatLanguageModelServiceImpl).toSelf().inSingletonScope(); + rebind(LanguageModelService).toService(ChatLanguageModelServiceImpl); }); diff --git a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts new file mode 100644 index 0000000000000..614c0608058f8 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts @@ -0,0 +1,441 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { Container } from '@theia/core/shared/inversify'; +import { ILogger, PreferenceService } from '@theia/core'; +import { + LanguageModel, + LanguageModelRegistry, + LanguageModelStreamResponse, + LanguageModelStreamResponsePart, + UserRequest, + isLanguageModelStreamResponse +} from '@theia/ai-core'; +import { ChatLanguageModelServiceImpl } from './chat-language-model-service'; +import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatSessionSummarizationService } from './chat-session-summarization-service'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '../common/ai-chat-preferences'; +import { PREFERENCE_NAME_REQUEST_SETTINGS } from '@theia/ai-core/lib/common/ai-core-preferences'; + +describe('ChatLanguageModelServiceImpl', () => { + let container: Container; + let service: ChatLanguageModelServiceImpl; + let mockLanguageModel: sinon.SinonStubbedInstance; + let mockPreferenceService: sinon.SinonStubbedInstance; + let mockTokenTracker: sinon.SinonStubbedInstance; + let mockSummarizationService: sinon.SinonStubbedInstance; + let mockLogger: sinon.SinonStubbedInstance; + + beforeEach(() => { + container = new Container(); + + // Create mocks + mockLanguageModel = { + id: 'test/model', + name: 'Test Model', + request: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockPreferenceService = { + get: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockTokenTracker = { + getSessionInputTokens: sinon.stub(), + onThresholdExceeded: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockSummarizationService = { + triggerSummarization: sinon.stub(), + hasSummary: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockLogger = { + info: sinon.stub(), + warn: sinon.stub(), + error: sinon.stub(), + debug: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + // Bind mocks + container.bind(LanguageModelRegistry).toConstantValue({} as LanguageModelRegistry); + container.bind(PreferenceService).toConstantValue(mockPreferenceService); + container.bind(ChatSessionTokenTracker).toConstantValue(mockTokenTracker); + container.bind(ChatSessionSummarizationService).toConstantValue(mockSummarizationService); + container.bind(ILogger).toConstantValue(mockLogger); + container.bind(ChatLanguageModelServiceImpl).toSelf().inSingletonScope(); + + service = container.get(ChatLanguageModelServiceImpl); + + // Default preference setup + mockPreferenceService.get.withArgs(PREFERENCE_NAME_REQUEST_SETTINGS, []).returns([]); + }); + + afterEach(() => { + sinon.restore(); + }); + + describe('sendRequest', () => { + it('should delegate to super when preference is disabled', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(false); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + expect(mockSummarizationService.triggerSummarization.called).to.be.false; + }); + + it('should delegate to super when request has no tools', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }] + // No tools + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + expect(mockSummarizationService.triggerSummarization.called).to.be.false; + }); + + it('should use budget-aware handling when preference is enabled and tools are present', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); // Below threshold + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify singleRoundTrip was set + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.singleRoundTrip).to.be.true; + }); + }); + + describe('budget checking', () => { + it('should trigger summarization after tool execution when budget is exceeded mid-loop', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + // Return over threshold - budget check happens before request and after tool execution + mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); + mockSummarizationService.triggerSummarization.resolves(undefined); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // First call: model returns tool call without result + const firstStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Second call: model returns final response + const secondStream = createMockStream([{ content: 'Done!' }]); + + mockLanguageModel.request + .onFirstCall().resolves({ stream: firstStream }) + .onSecondCall().resolves({ stream: secondStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the tool loop + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify summarization was triggered both before request and after tool execution + expect(mockSummarizationService.triggerSummarization.calledTwice).to.be.true; + // First call (before request): no skipReorder + expect(mockSummarizationService.triggerSummarization.firstCall.calledWith('session-1')).to.be.true; + // Second call (mid-turn, after tool execution): skipReorder=true + expect(mockSummarizationService.triggerSummarization.secondCall.calledWith('session-1', true)).to.be.true; + }); + + it('should trigger summarization before request when budget is exceeded', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); + mockSummarizationService.triggerSummarization.resolves(undefined); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + // Model returns response without tool calls + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Summarization should be triggered before the request since budget is exceeded + expect(mockSummarizationService.triggerSummarization.calledOnce).to.be.true; + expect(mockSummarizationService.triggerSummarization.calledWith('session-1')).to.be.true; + }); + + it('should preserve original messages (no message rebuilding)', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); // Below threshold + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [ + { actor: 'system', type: 'text', text: 'System prompt' }, + { actor: 'user', type: 'text', text: 'Hello' } + ], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify original messages are preserved + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.messages).to.have.length(2); + expect((actualRequest.messages[0] as { text: string }).text).to.equal('System prompt'); + expect((actualRequest.messages[1] as { text: string }).text).to.equal('Hello'); + }); + }); + + describe('tool loop handling', () => { + it('should execute tools and continue loop when model respects singleRoundTrip', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // First call: model returns tool call without result (respected singleRoundTrip) + const firstStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Second call: model returns final response + const secondStream = createMockStream([ + { content: 'Done!' } + ]); + + mockLanguageModel.request + .onFirstCall().resolves({ stream: firstStream }) + .onSecondCall().resolves({ stream: secondStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume the stream to trigger the tool loop + const parts: LanguageModelStreamResponsePart[] = []; + for await (const part of (response as LanguageModelStreamResponse).stream) { + parts.push(part); + } + + // Verify tool was executed + expect(toolHandler.calledOnce).to.be.true; + + // Verify two LLM calls were made (initial + continuation) + expect(mockLanguageModel.request.calledTwice).to.be.true; + }); + + it('should not execute tools when model ignores singleRoundTrip (has results)', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // Model returns tool call WITH result (ignored singleRoundTrip, handled internally) + const mockStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true, result: 'internal result' }] }, + { content: 'Done!' } + ]); + + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume the stream + const parts: LanguageModelStreamResponsePart[] = []; + for await (const part of (response as LanguageModelStreamResponse).stream) { + parts.push(part); + } + + // Tool should NOT have been executed by our service (model did it) + expect(toolHandler.called).to.be.false; + + // Only one LLM call (model handled everything) + expect(mockLanguageModel.request.calledOnce).to.be.true; + }); + }); + + describe('error handling', () => { + it('should log context-too-long errors and re-throw', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const contextError = new Error('Request too long: context exceeds maximum token limit'); + mockLanguageModel.request.rejects(contextError); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consuming the stream should throw + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw before we get here + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.equal(contextError); + expect(mockLogger.error.calledOnce).to.be.true; + expect(mockLogger.error.firstCall.args[0]).to.include('Context too long'); + } + }); + + it('should propagate non-context errors without special logging', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const networkError = new Error('Network connection failed'); + mockLanguageModel.request.rejects(networkError); + + const response = await service.sendRequest(mockLanguageModel, request); + + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.equal(networkError); + // Should not have logged "Context too long" message + expect(mockLogger.error.called).to.be.false; + } + }); + }); +}); + +/** + * Helper to create a mock async iterable stream from an array of parts. + */ +function createMockStream(parts: LanguageModelStreamResponsePart[]): AsyncIterable { + return { + async *[Symbol.asyncIterator](): AsyncIterator { + for (const part of parts) { + yield part; + } + } + }; +} diff --git a/packages/ai-chat/src/browser/chat-language-model-service.ts b/packages/ai-chat/src/browser/chat-language-model-service.ts new file mode 100644 index 0000000000000..3e7bceda531e6 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-language-model-service.ts @@ -0,0 +1,317 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { injectable, inject } from '@theia/core/shared/inversify'; +import { ILogger, PreferenceService } from '@theia/core'; +import { LanguageModelServiceImpl } from '@theia/ai-core/lib/common/language-model-service'; +import { + LanguageModel, + LanguageModelResponse, + LanguageModelStreamResponse, + LanguageModelStreamResponsePart, + UserRequest, + isLanguageModelStreamResponse, + isToolCallResponsePart, + ToolCall, + ToolRequest, + ToolCallResult, + LanguageModelMessage, + LanguageModelRegistry +} from '@theia/ai-core'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '../common/ai-chat-preferences'; +import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatSessionSummarizationService } from './chat-session-summarization-service'; +import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/common/ai-core-preferences'; +import { mergeRequestSettings } from '@theia/ai-core/lib/browser/frontend-language-model-service'; + +/** + * Chat-specific language model service that adds budget-aware tool loop handling. + * Extends LanguageModelServiceImpl to intercept sendRequest() calls. + * + * When the experimental preference is enabled, this service: + * 1. Sets singleRoundTrip=true to prevent models from handling tool loops internally + * 2. Manages the tool loop externally with budget checks between iterations + * 3. Triggers summarization when token budget is exceeded mid-turn + * + * Models that don't support singleRoundTrip will ignore the flag - this is detected + * by checking if tool_calls have results attached (model handled internally) vs + * no results (model respected the flag). + */ +@injectable() +export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { + + @inject(LanguageModelRegistry) + protected override languageModelRegistry: LanguageModelRegistry; + + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + + @inject(PreferenceService) + protected readonly preferenceService: PreferenceService; + + @inject(ILogger) + protected readonly logger: ILogger; + + @inject(ChatSessionSummarizationService) + protected readonly summarizationService: ChatSessionSummarizationService; + + override async sendRequest( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + // Apply request settings (matching FrontendLanguageModelServiceImpl behavior) + const requestSettings = this.preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); + const ids = languageModel.id.split('/'); + const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], request.agentId); + if (matchingSetting?.requestSettings) { + request.settings = { + ...matchingSetting.requestSettings, + ...request.settings + }; + } + if (matchingSetting?.clientSettings) { + request.clientSettings = { + ...matchingSetting.clientSettings, + ...request.clientSettings + }; + } + + const budgetAwareEnabled = this.preferenceService.get(BUDGET_AWARE_TOOL_LOOP_PREF, false); + + if (budgetAwareEnabled && request.tools?.length) { + return this.sendRequestWithBudgetAwareness(languageModel, request); + } + + return super.sendRequest(languageModel, request); + } + + /** + * Send request with budget-aware tool loop handling. + * Manages the tool loop externally, checking token budget between iterations + * and triggering summarization when needed. + */ + protected async sendRequestWithBudgetAwareness( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + // Check if budget is exceeded BEFORE sending + if (request.sessionId && this.isBudgetExceeded(request.sessionId)) { + this.logger.info(`Budget exceeded before request for session ${request.sessionId}, triggering summarization...`); + await this.summarizationService.triggerSummarization(request.sessionId, false); + } + + const modifiedRequest: UserRequest = { + ...request, + singleRoundTrip: true + }; + return this.executeToolLoop(languageModel, modifiedRequest); + } + + /** + * Execute the tool loop, handling tool calls and budget checks between iterations. + */ + protected async executeToolLoop( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + const that = this; + const sessionId = request.sessionId; + const tools = request.tools ?? []; + + // State that persists across the async iterator + let currentMessages = [...request.messages]; + let pendingToolCalls: ToolCall[] = []; + let modelHandledLoop = false; + + const asyncIterator = { + async *[Symbol.asyncIterator](): AsyncIterator { + let continueLoop = true; + + while (continueLoop) { + continueLoop = false; + pendingToolCalls = []; + modelHandledLoop = false; + + // Create request with current messages + const currentRequest: UserRequest = { + ...request, + messages: currentMessages, + singleRoundTrip: true + }; + + let response: LanguageModelResponse; + try { + // Call the parent's sendRequest to get the response + response = await LanguageModelServiceImpl.prototype.sendRequest.call( + that, languageModel, currentRequest + ); + } catch (error) { + // Check if this is a "context too long" error + const errorMessage = error instanceof Error ? error.message : String(error); + if (errorMessage.toLowerCase().includes('context') || + errorMessage.toLowerCase().includes('token') || + errorMessage.toLowerCase().includes('too long') || + errorMessage.toLowerCase().includes('max_tokens')) { + that.logger.error( + 'Context too long error for session ' + sessionId + '. ' + + 'Cannot recover - summarization also requires an LLM call.', + error + ); + } + // Re-throw to let the chat agent handle and display the error + throw error; + } + + if (!isLanguageModelStreamResponse(response)) { + // Non-streaming response - just return as-is + // This shouldn't happen with singleRoundTrip but handle gracefully + return; + } + + // Process the stream + for await (const part of response.stream) { + // Collect tool calls to check if model respected singleRoundTrip + if (isToolCallResponsePart(part)) { + for (const toolCall of part.tool_calls) { + // If any tool call has a result, the model handled the loop internally + if (toolCall.result !== undefined) { + modelHandledLoop = true; + } + // Collect finished tool calls without results (model respected singleRoundTrip) + if (toolCall.finished && toolCall.result === undefined && toolCall.id) { + pendingToolCalls.push(toolCall); + } + } + } + yield part; + } + + // If model handled the loop internally, we're done + if (modelHandledLoop) { + return; + } + + // If there are pending tool calls, execute them and continue the loop + if (pendingToolCalls.length > 0) { + // Execute tools + const toolResults = await that.executeTools(pendingToolCalls, tools); + + // Check budget after tool execution + if (that.isBudgetExceeded(sessionId)) { + that.logger.info(`Budget exceeded after tool execution for session ${sessionId}, triggering summarization...`); + // Pass skipReorder=true for mid-turn summarization to avoid disrupting the active request + await that.summarizationService.triggerSummarization(sessionId, true); + } + + // Append tool messages to current messages + currentMessages = that.appendToolMessages( + currentMessages, + pendingToolCalls, + toolResults + ); + + // Yield tool call results + const resultsToYield = pendingToolCalls.map(tc => ({ + finished: true, + id: tc.id, + result: toolResults.get(tc.id!), + function: tc.function + })); + yield { tool_calls: resultsToYield }; + + continueLoop = true; + } + } + } + }; + + return { stream: asyncIterator }; + } + + /** + * Check if the token budget is exceeded for a session. + */ + protected isBudgetExceeded(sessionId: string): boolean { + const tokens = this.tokenTracker.getSessionInputTokens(sessionId); + return tokens !== undefined && tokens >= CHAT_TOKEN_THRESHOLD; + } + + /** + * Execute tool calls and collect results. + */ + protected async executeTools( + toolCalls: ToolCall[], + toolRequests: ToolRequest[] + ): Promise> { + const results = new Map(); + + for (const toolCall of toolCalls) { + const toolRequest = toolRequests.find(t => t.name === toolCall.function?.name); + if (toolRequest && toolCall.id && toolCall.function?.arguments) { + try { + const result = await toolRequest.handler(toolCall.function.arguments); + results.set(toolCall.id, result); + } catch (error) { + this.logger.error(`Tool execution failed for ${toolCall.function?.name}:`, error); + results.set(toolCall.id, { type: 'error', data: String(error) } as ToolCallResult); + } + } + } + + return results; + } + + /** + * Append tool_use and tool_result messages to the message array. + */ + protected appendToolMessages( + messages: LanguageModelMessage[], + toolCalls: ToolCall[], + toolResults: Map + ): LanguageModelMessage[] { + const newMessages: LanguageModelMessage[] = [...messages]; + + // Add tool_use messages (AI requesting tool calls) + for (const toolCall of toolCalls) { + if (toolCall.id && toolCall.function?.name) { + newMessages.push({ + actor: 'ai', + type: 'tool_use', + id: toolCall.id, + name: toolCall.function.name, + input: toolCall.function.arguments ? JSON.parse(toolCall.function.arguments) : {} + }); + } + } + + // Add tool_result messages (user providing results) + for (const toolCall of toolCalls) { + if (toolCall.id && toolCall.function?.name) { + const result = toolResults.get(toolCall.id); + newMessages.push({ + actor: 'user', + type: 'tool_result', + tool_use_id: toolCall.id, + name: toolCall.function.name, + content: result + }); + } + } + + return newMessages; + } +} + diff --git a/packages/ai-chat/src/browser/chat-session-store-impl.ts b/packages/ai-chat/src/browser/chat-session-store-impl.ts index f4ee2a062ce17..795c264674033 100644 --- a/packages/ai-chat/src/browser/chat-session-store-impl.ts +++ b/packages/ai-chat/src/browser/chat-session-store-impl.ts @@ -27,6 +27,7 @@ import { ChatModel } from '../common/chat-model'; import { ChatSessionIndex, ChatSessionStore, ChatModelWithMetadata, ChatSessionMetadata } from '../common/chat-session-store'; import { PERSISTED_SESSION_LIMIT_PREF } from '../common/ai-chat-preferences'; import { SerializedChatData, CHAT_DATA_VERSION } from '../common/chat-model-serialization'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; const INDEX_FILE = 'index.json'; @@ -50,6 +51,9 @@ export class ChatSessionStoreImpl implements ChatSessionStore { @inject(PreferenceService) protected readonly preferenceService: PreferenceService; + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + protected storageRoot?: URI; protected indexCache?: ChatSessionIndex; protected storePromise: Promise = Promise.resolve(); @@ -75,7 +79,8 @@ export class ChatSessionStoreImpl implements ChatSessionStore { title: session.title, pinnedAgentId: session.pinnedAgentId, saveDate: session.saveDate, - model: modelData + model: modelData, + lastInputTokens: this.tokenTracker.getSessionInputTokens(session.model.id) }; this.logger.debug('Writing session to file', { sessionId: session.model.id, diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts new file mode 100644 index 0000000000000..e2e9aff5da2c1 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -0,0 +1,215 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; +import { Disposable, ILogger } from '@theia/core'; +import { FrontendApplicationContribution } from '@theia/core/lib/browser'; +import { AgentService, TokenUsageServiceClient } from '@theia/ai-core'; +import { + ChatAgent, + ChatService, + MutableChatModel, + MutableChatRequestModel +} from '../common'; +import { ChatSessionSummaryAgent } from '../common/chat-session-summary-agent'; +import { + ChatSessionTokenTracker, + SessionTokenThresholdEvent +} from './chat-session-token-tracker'; + +export const ChatSessionSummarizationService = Symbol('ChatSessionSummarizationService'); + +/** + * Service that automatically summarizes chat sessions when token usage exceeds the threshold. + * + * When the threshold is exceeded: + * 1. Marks older messages as stale (excluding them from future prompts) + * 2. Invokes ChatSessionSummaryAgent to generate a summary + * 3. Inserts a summary node into the chat + */ +export interface ChatSessionSummarizationService { + /** + * Check if a session has been summarized (has stale messages). + */ + hasSummary(sessionId: string): boolean; + + /** + * Trigger summarization for a session. + * Called by the budget-aware tool loop when token threshold is exceeded mid-turn. + * + * @param sessionId The session to summarize + * @param skipReorder If true, skip removing/re-adding the trigger request (for mid-turn summarization) + * @returns Promise that resolves with the summary text on success, or `undefined` on failure + */ + triggerSummarization(sessionId: string, skipReorder: boolean): Promise; +} + +@injectable() +export class ChatSessionSummarizationServiceImpl implements ChatSessionSummarizationService, FrontendApplicationContribution { + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + + @inject(ChatService) + protected readonly chatService: ChatService; + + @inject(AgentService) + protected readonly agentService: AgentService; + + @inject(ILogger) + protected readonly logger: ILogger; + + @inject(TokenUsageServiceClient) + protected readonly tokenUsageClient: TokenUsageServiceClient; + + /** + * Set of sessionIds currently being summarized to prevent concurrent summarization. + */ + protected summarizingSession = new Set(); + + @postConstruct() + protected init(): void { + this.tokenTracker.onThresholdExceeded(event => this.handleThresholdExceeded(event)); + } + + /** + * Called when the frontend application starts. + * Required by FrontendApplicationContribution to ensure this service is instantiated. + */ + onStart(): void { + // Service initialization is handled in @postConstruct + } + + async triggerSummarization(sessionId: string, skipReorder: boolean): Promise { + const session = this.chatService.getSession(sessionId); + if (!session) { + this.logger.warn(`Session ${sessionId} not found for summarization`); + return undefined; + } + + this.logger.info(`Mid-turn summarization triggered for session ${sessionId}`); + return this.performSummarization(sessionId, session.model as MutableChatModel, skipReorder); + } + + protected async handleThresholdExceeded(event: SessionTokenThresholdEvent): Promise { + const { sessionId, inputTokens } = event; + + if (this.summarizingSession.has(sessionId)) { + return; + } + + const session = this.chatService.getSession(sessionId); + if (!session) { + this.logger.warn(`Session ${sessionId} not found for summarization`); + return; + } + + this.logger.info(`Token threshold exceeded for session ${sessionId}: ${inputTokens} tokens. Starting summarization...`); + await this.performSummarization(sessionId, session.model as MutableChatModel); + } + + /** + * Core summarization logic shared by both threshold-triggered and explicit mid-turn summarization. + * + * @param skipReorder If true, skip removing/re-adding the trigger request (for mid-turn summarization + * where the request is actively being processed with tool calls) + * @returns The summary text on success, or `undefined` on failure + */ + protected async performSummarization(sessionId: string, model: MutableChatModel, skipReorder?: boolean): Promise { + if (this.summarizingSession.has(sessionId)) { + return undefined; + } + + this.summarizingSession.add(sessionId); + + try { + const position = skipReorder ? 'end' : 'beforeLast'; + + const summaryText = await model.insertSummary( + async summaryRequest => { + // Find and invoke the summary agent + const agent = this.agentService.getAgents().find( + (candidate): candidate is ChatAgent => + 'invoke' in candidate && + typeof candidate.invoke === 'function' && + candidate.id === ChatSessionSummaryAgent.ID + ); + + if (!agent) { + this.logger.error('ChatSessionSummaryAgent not found'); + return undefined; + } + + // Set up listener to capture token usage + let capturedInputTokens: number | undefined; + const tokenUsageListener: Disposable = this.tokenUsageClient.onTokenUsageUpdated(usage => { + if (usage.requestId === summaryRequest.id) { + capturedInputTokens = usage.inputTokens; + } + }); + + try { + await agent.invoke(summaryRequest); + } finally { + tokenUsageListener.dispose(); + } + + // Store captured tokens for later use + if (capturedInputTokens !== undefined) { + summaryRequest.addData('capturedInputTokens', capturedInputTokens); + } + + return summaryRequest.response.response.asDisplayString(); + }, + position + ); + + if (!summaryText) { + this.logger.warn(`Summarization failed for session ${sessionId}`); + return undefined; + } + + this.logger.info(`Added summary node to session ${sessionId}`); + + // Reset token count using captured tokens + const lastSummaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); + const capturedTokens = lastSummaryRequest?.getDataByKey('capturedInputTokens'); + if (capturedTokens !== undefined) { + this.tokenTracker.resetSessionTokens(sessionId, capturedTokens); + this.tokenTracker.resetThresholdTrigger(sessionId); + this.logger.info(`Reset token count for session ${sessionId} to ${capturedTokens} tokens`); + } + + return summaryText; + + } catch (error) { + this.logger.error(`Failed to summarize session ${sessionId}:`, error); + return undefined; + } finally { + this.summarizingSession.delete(sessionId); + } + } + + hasSummary(sessionId: string): boolean { + const session = this.chatService.getSession(sessionId); + if (!session) { + return false; + } + + const model = session.model as MutableChatModel; + return model.getRequests().some(r => (r as MutableChatRequestModel).isStale); + } + +} diff --git a/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts b/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts new file mode 100644 index 0000000000000..c366503e137e5 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts @@ -0,0 +1,42 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { inject, injectable } from '@theia/core/shared/inversify'; +import { FrontendApplicationContribution } from '@theia/core/lib/browser'; +import { ChatService, isSessionCreatedEvent } from '../common/chat-service'; +import { ChatSessionTokenTracker } from '../common/chat-session-token-tracker'; + +/** + * Contribution that wires ChatService session events to the token tracker. + * This breaks the circular dependency between ChatService and ChatSessionTokenTracker + * by deferring the wiring until after both services are fully constructed. + */ +@injectable() +export class ChatSessionTokenRestoreContribution implements FrontendApplicationContribution { + @inject(ChatService) + protected readonly chatService: ChatService; + + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + + onStart(): void { + this.chatService.onSessionEvent(event => { + if (isSessionCreatedEvent(event) && event.tokenCount !== undefined) { + this.tokenTracker.resetSessionTokens(event.sessionId, event.tokenCount); + } + }); + } +} diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts new file mode 100644 index 0000000000000..2adababb48c65 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts @@ -0,0 +1,207 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { Container } from '@theia/core/shared/inversify'; +import { Emitter } from '@theia/core'; +import { TokenUsageServiceClient, TokenUsage } from '@theia/ai-core/lib/common'; +import { ChatSessionTokenTrackerImpl, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { SessionTokenThresholdEvent, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +describe('ChatSessionTokenTrackerImpl', () => { + let container: Container; + let tracker: ChatSessionTokenTrackerImpl; + let mockTokenUsageEmitter: Emitter; + let mockTokenUsageClient: TokenUsageServiceClient; + + const createTokenUsage = (sessionId: string | undefined, inputTokens: number, requestId: string): TokenUsage => ({ + sessionId, + inputTokens, + outputTokens: 100, + requestId, + model: 'test-model', + timestamp: new Date() + }); + + beforeEach(() => { + container = new Container(); + + // Create a mock TokenUsageServiceClient with controllable event emitter + mockTokenUsageEmitter = new Emitter(); + mockTokenUsageClient = { + notifyTokenUsage: sinon.stub(), + onTokenUsageUpdated: mockTokenUsageEmitter.event + }; + + // Bind dependencies + container.bind(TokenUsageServiceClient).toConstantValue(mockTokenUsageClient); + container.bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); + + tracker = container.get(ChatSessionTokenTrackerImpl); + }); + + afterEach(() => { + mockTokenUsageEmitter.dispose(); + sinon.restore(); + }); + + describe('getSessionInputTokens', () => { + it('should return correct token count after usage is reported', () => { + const sessionId = 'session-1'; + const inputTokens = 5000; + + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(inputTokens); + }); + + it('should return undefined for unknown session', () => { + expect(tracker.getSessionInputTokens('unknown-session')).to.be.undefined; + }); + }); + + describe('onThresholdExceeded', () => { + it('should fire when tokens exceed threshold', () => { + const sessionId = 'session-1'; + const inputTokens = CHAT_TOKEN_THRESHOLD + 1000; + const thresholdEvents: SessionTokenThresholdEvent[] = []; + + tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); + + expect(thresholdEvents).to.have.length(1); + expect(thresholdEvents[0].sessionId).to.equal(sessionId); + expect(thresholdEvents[0].inputTokens).to.equal(inputTokens); + }); + + it('should not fire when tokens are below threshold', () => { + const sessionId = 'session-1'; + const inputTokens = CHAT_TOKEN_THRESHOLD - 1000; + const thresholdEvents: SessionTokenThresholdEvent[] = []; + + tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); + + expect(thresholdEvents).to.have.length(0); + }); + + it('should not fire twice for the same session without reset', () => { + const sessionId = 'session-1'; + const thresholdEvents: SessionTokenThresholdEvent[] = []; + + tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + + // First token usage exceeding threshold + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); + + // Second token usage exceeding threshold (should not trigger again) + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 2000, 'request-2')); + + expect(thresholdEvents).to.have.length(1); + }); + }); + + describe('resetThresholdTrigger', () => { + it('should allow re-triggering after resetThresholdTrigger is called', () => { + const sessionId = 'session-1'; + const thresholdEvents: SessionTokenThresholdEvent[] = []; + + tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + + // First token usage exceeding threshold + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); + + expect(thresholdEvents).to.have.length(1); + + // Reset the threshold trigger (simulating summarization completion) + tracker.resetThresholdTrigger(sessionId); + + // Second token usage exceeding threshold should trigger again + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 3000, 'request-2')); + + expect(thresholdEvents).to.have.length(2); + expect(thresholdEvents[0].sessionId).to.equal(sessionId); + expect(thresholdEvents[1].sessionId).to.equal(sessionId); + }); + + it('should not affect other sessions', () => { + const sessionId1 = 'session-1'; + const sessionId2 = 'session-2'; + const thresholdEvents: SessionTokenThresholdEvent[] = []; + + tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + + // Trigger threshold for session 1 + mockTokenUsageEmitter.fire(createTokenUsage(sessionId1, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); + + // Trigger threshold for session 2 + mockTokenUsageEmitter.fire(createTokenUsage(sessionId2, CHAT_TOKEN_THRESHOLD + 1000, 'request-2')); + + expect(thresholdEvents).to.have.length(2); + + // Reset only session 1 + tracker.resetThresholdTrigger(sessionId1); + + // Session 1 should be able to trigger again + mockTokenUsageEmitter.fire(createTokenUsage(sessionId1, CHAT_TOKEN_THRESHOLD + 2000, 'request-3')); + + // Session 2 should not trigger again (not reset) + mockTokenUsageEmitter.fire(createTokenUsage(sessionId2, CHAT_TOKEN_THRESHOLD + 2000, 'request-4')); + + expect(thresholdEvents).to.have.length(3); + expect(thresholdEvents[2].sessionId).to.equal(sessionId1); + }); + }); + + describe('resetSessionTokens', () => { + it('should update token count and fire onSessionTokensUpdated', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Set initial token count + mockTokenUsageEmitter.fire(createTokenUsage(sessionId, 50000, 'request-1')); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(updateEvents).to.have.length(1); + + // Reset to new baseline (simulating post-summarization) + const newTokenCount = 10000; + tracker.resetSessionTokens(sessionId, newTokenCount); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(newTokenCount); + expect(updateEvents).to.have.length(2); + expect(updateEvents[1].sessionId).to.equal(sessionId); + expect(updateEvents[1].inputTokens).to.equal(newTokenCount); + }); + }); + + describe('token usage handling', () => { + it('should ignore token usage without sessionId', () => { + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + mockTokenUsageEmitter.fire(createTokenUsage(undefined, 5000, 'request-1')); + + expect(updateEvents).to.have.length(0); + }); + }); +}); diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.ts new file mode 100644 index 0000000000000..1bc86679bebaf --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.ts @@ -0,0 +1,105 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; +import { TokenUsageServiceClient, TokenUsage } from '@theia/ai-core/lib/common'; +import { Emitter } from '@theia/core'; +import { ChatSessionTokenTracker, SessionTokenThresholdEvent, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +// Re-export from common for backwards compatibility +export { ChatSessionTokenTracker, SessionTokenUpdateEvent, SessionTokenThresholdEvent } from '../common/chat-session-token-tracker'; + +/** + * Hardcoded token budget and threshold for chat sessions. + */ +export const CHAT_TOKEN_BUDGET = 200000; +export const CHAT_TOKEN_THRESHOLD_PERCENT = 0.9; +export const CHAT_TOKEN_THRESHOLD = CHAT_TOKEN_BUDGET * CHAT_TOKEN_THRESHOLD_PERCENT; + +@injectable() +export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { + @inject(TokenUsageServiceClient) + protected readonly tokenUsageClient: TokenUsageServiceClient; + + protected readonly onThresholdExceededEmitter = new Emitter(); + readonly onThresholdExceeded = this.onThresholdExceededEmitter.event; + + protected readonly onSessionTokensUpdatedEmitter = new Emitter(); + readonly onSessionTokensUpdated = this.onSessionTokensUpdatedEmitter.event; + + /** + * Map of sessionId -> latest inputTokens count. + * Updated when token usage is reported for requests in that session. + */ + protected sessionTokens = new Map(); + + /** + * Set of sessionIds that have already triggered the threshold event. + * Prevents multiple triggers for the same session. + */ + protected triggeredSessions = new Set(); + + @postConstruct() + protected init(): void { + this.tokenUsageClient.onTokenUsageUpdated(usage => this.handleTokenUsage(usage)); + } + + protected handleTokenUsage(usage: TokenUsage): void { + const { sessionId, inputTokens } = usage; + + if (!sessionId) { + return; // Can't track without sessionId + } + + // Update the session's token count + this.sessionTokens.set(sessionId, inputTokens); + + // Fire the token update event + this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens }); + + // Check if threshold is exceeded and we haven't already triggered + if (inputTokens >= CHAT_TOKEN_THRESHOLD && !this.triggeredSessions.has(sessionId)) { + this.triggeredSessions.add(sessionId); + this.onThresholdExceededEmitter.fire({ + sessionId, + inputTokens + }); + } + } + + getSessionInputTokens(sessionId: string): number | undefined { + return this.sessionTokens.get(sessionId); + } + + /** + * Reset the triggered state for a session. + * Called after summarization is complete to allow future triggers + * if the session continues to grow. + */ + resetThresholdTrigger(sessionId: string): void { + this.triggeredSessions.delete(sessionId); + } + + /** + * Reset the session's token count to a new baseline. + * Called after summarization to reflect the reduced token usage. + * The new count should reflect only the summary + any non-stale messages. + */ + resetSessionTokens(sessionId: string, newTokenCount: number): void { + this.sessionTokens.set(sessionId, newTokenCount); + this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount }); + } +} diff --git a/packages/ai-chat/src/browser/index.ts b/packages/ai-chat/src/browser/index.ts new file mode 100644 index 0000000000000..67734e22c894f --- /dev/null +++ b/packages/ai-chat/src/browser/index.ts @@ -0,0 +1,17 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +export * from './chat-session-token-tracker'; diff --git a/packages/ai-chat/src/common/ai-chat-preferences.ts b/packages/ai-chat/src/common/ai-chat-preferences.ts index bd40ccb1aaac6..9a3226b23aafd 100644 --- a/packages/ai-chat/src/common/ai-chat-preferences.ts +++ b/packages/ai-chat/src/common/ai-chat-preferences.ts @@ -20,6 +20,7 @@ import { nls, PreferenceSchema } from '@theia/core'; export const DEFAULT_CHAT_AGENT_PREF = 'ai-features.chat.defaultChatAgent'; export const PIN_CHAT_AGENT_PREF = 'ai-features.chat.pinChatAgent'; export const BYPASS_MODEL_REQUIREMENT_PREF = 'ai-features.chat.bypassModelRequirement'; +export const BUDGET_AWARE_TOOL_LOOP_PREF = 'ai-features.chat.experimentalBudgetAwareToolLoop'; export const PERSISTED_SESSION_LIMIT_PREF = 'ai-features.chat.persistedSessionLimit'; export const aiChatPreferences: PreferenceSchema = { @@ -46,6 +47,15 @@ export const aiChatPreferences: PreferenceSchema = { default: false, title: AI_CORE_PREFERENCES_TITLE, }, + [BUDGET_AWARE_TOOL_LOOP_PREF]: { + type: 'boolean', + description: nls.localize('theia/ai/chat/budgetAwareToolLoop/description', + 'Experimental: Enable budget-aware tool loop. When enabled, the chat agent can trigger summarization mid-turn \ +if the token budget is exceeded during tool call loops. This prevents API errors from context overflow. \ +Requires language model support (currently only Anthropic models).'), + default: false, + title: AI_CORE_PREFERENCES_TITLE, + }, [PERSISTED_SESSION_LIMIT_PREF]: { type: 'number', description: nls.localize('theia/ai/chat/persistedSessionLimit/description', diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 752e2e5bf0191..773512f90ad6d 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -291,6 +291,11 @@ export abstract class AbstractChatAgent implements ChatAgent { model: ChatModel, includeResponseInProgress = false ): Promise { const requestMessages = model.getRequests().flatMap(request => { + // Skip stale requests entirely - their content is replaced by summary nodes + if (request.isStale === true) { + return []; + } + const messages: LanguageModelMessage[] = []; const text = request.message.parts.map(part => part.promptText).join(''); if (text.length > 0) { diff --git a/packages/ai-chat/src/common/chat-auto-save.spec.ts b/packages/ai-chat/src/common/chat-auto-save.spec.ts index db87465db5f0c..6e000bda4e99a 100644 --- a/packages/ai-chat/src/common/chat-auto-save.spec.ts +++ b/packages/ai-chat/src/common/chat-auto-save.spec.ts @@ -372,4 +372,5 @@ describe('Chat Auto-Save Mechanism', () => { expect(sessionStore.saveCount).to.be.greaterThan(0); }); }); + }); diff --git a/packages/ai-chat/src/common/chat-content-deserializer.spec.ts b/packages/ai-chat/src/common/chat-content-deserializer.spec.ts index 656847d3f9d19..02ab44832fb3a 100644 --- a/packages/ai-chat/src/common/chat-content-deserializer.spec.ts +++ b/packages/ai-chat/src/common/chat-content-deserializer.spec.ts @@ -30,6 +30,7 @@ import { MarkdownChatResponseContentImpl, ProgressChatResponseContentImpl, QuestionResponseContentImpl, + SummaryChatResponseContentImpl, TextChatResponseContentImpl, ThinkingChatResponseContentImpl, ToolCallChatResponseContentImpl @@ -347,6 +348,36 @@ describe('Chat Content Serialization', () => { }); }); + describe('SummaryChatResponseContentImpl', () => { + it('should serialize and deserialize correctly', async () => { + const original = new SummaryChatResponseContentImpl('This is a summary of the conversation.'); + const serialized = original.toSerializable?.(); + + expect(serialized).to.not.be.undefined; + expect(serialized!.kind).to.equal('summary'); + expect(serialized!.data).to.deep.equal({ content: 'This is a summary of the conversation.' }); + + // Simulate caller populating fallbackMessage + const withFallback = { + ...serialized!, + fallbackMessage: original.asString?.() || original.toString() + }; + + const deserialized = await registry.deserialize(withFallback); + expect(deserialized.kind).to.equal('summary'); + expect(deserialized.asString?.()).to.equal('This is a summary of the conversation.'); + }); + + it('should include summary prefix in language model message', () => { + const original = new SummaryChatResponseContentImpl('Summary content'); + const message = original.toLanguageModelMessage(); + + expect(message.type).to.equal('text'); + expect(message.text).to.include('[Summary of previous conversation]'); + expect(message.text).to.include('Summary content'); + }); + }); + describe('ChatContentDeserializerRegistry', () => { it('should handle unknown content types with fallback', async () => { const unknownContent = { diff --git a/packages/ai-chat/src/common/chat-content-deserializer.ts b/packages/ai-chat/src/common/chat-content-deserializer.ts index 7f862b832f222..62bd9160c0632 100644 --- a/packages/ai-chat/src/common/chat-content-deserializer.ts +++ b/packages/ai-chat/src/common/chat-content-deserializer.ts @@ -25,6 +25,7 @@ import { MarkdownChatResponseContentImpl, ProgressChatResponseContentImpl, QuestionResponseContentImpl, + SummaryChatResponseContentImpl, TextChatResponseContentImpl, ThinkingChatResponseContentImpl, ToolCallChatResponseContentImpl, @@ -39,7 +40,8 @@ import { HorizontalLayoutContentData, ProgressContentData, ErrorContentData, - QuestionContentData + QuestionContentData, + SummaryContentData } from './chat-model'; import { SerializableChatResponseContentData } from './chat-model-serialization'; import { ContributionProvider, ILogger, MaybePromise } from '@theia/core'; @@ -324,5 +326,10 @@ export class DefaultChatContentDeserializerContribution implements ChatContentDe data.selectedOption ) }); + + registry.register({ + kind: 'summary', + deserialize: (data: SummaryContentData) => new SummaryChatResponseContentImpl(data.content) + }); } } diff --git a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts new file mode 100644 index 0000000000000..c4fdaea379063 --- /dev/null +++ b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts @@ -0,0 +1,319 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import { ChatAgentLocation } from './chat-agents'; +import { ChatResponseContent, MutableChatModel, MutableChatRequestModel, SummaryChatResponseContent, TextChatResponseContentImpl } from './chat-model'; +import { ParsedChatRequest } from './parsed-chat-request'; + +describe('MutableChatModel.insertSummary()', () => { + + function createParsedRequest(text: string): ParsedChatRequest { + return { + request: { text }, + parts: [{ + kind: 'text', + text, + promptText: text, + range: { start: 0, endExclusive: text.length } + }], + toolRequests: new Map(), + variables: [] + }; + } + + function createModelWithRequests(count: number): MutableChatModel { + const model = new MutableChatModel(ChatAgentLocation.Panel); + for (let i = 1; i <= count; i++) { + const req = model.addRequest(createParsedRequest(`Request ${i}`)); + req.response.complete(); + } + return model; + } + + describe('basic functionality', () => { + it('should return undefined when model has less than 2 requests', async () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + model.addRequest(createParsedRequest('Single request')); + + const result = await model.insertSummary( + async () => 'Summary text', + 'end' + ); + + expect(result).to.be.undefined; + }); + + it('should return undefined when model is empty', async () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + + const result = await model.insertSummary( + async () => 'Summary text', + 'end' + ); + + expect(result).to.be.undefined; + }); + + it('should return summary text on success', async () => { + const model = createModelWithRequests(3); + + const result = await model.insertSummary( + async () => 'This is a summary', + 'end' + ); + + expect(result).to.equal('This is a summary'); + }); + }); + + describe('position: end', () => { + it('should append summary at the end', async () => { + const model = createModelWithRequests(3); + + await model.insertSummary( + async () => 'Summary text', + 'end' + ); + + const requests = model.getRequests(); + // Should have 4 requests: 3 original + 1 summary + expect(requests).to.have.lengthOf(4); + expect(requests[3].request.kind).to.equal('summary'); + }); + + it('should mark all requests except the last as stale', async () => { + const model = createModelWithRequests(3); + + await model.insertSummary( + async () => 'Summary text', + 'end' + ); + + const requests = model.getRequests(); + // Requests 1-2 (indices 0-1) should be stale, request 3 should not be + expect(requests[0].isStale).to.be.true; + expect(requests[1].isStale).to.be.true; + expect(requests[2].isStale).to.be.false; + // Summary request (index 3) should also not be stale + expect(requests[3].isStale).to.be.false; + }); + + it('should create SummaryChatResponseContent in response', async () => { + const model = createModelWithRequests(2); + + await model.insertSummary( + async () => 'The conversation summary', + 'end' + ); + + const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); + expect(summaryRequest).to.not.be.undefined; + + const content = summaryRequest!.response.response.content; + expect(content).to.have.lengthOf(1); + expect(SummaryChatResponseContent.is(content[0])).to.be.true; + expect((content[0] as SummaryChatResponseContent).content).to.equal('The conversation summary'); + }); + }); + + describe('position: beforeLast', () => { + it('should insert summary before the last request', async () => { + const model = createModelWithRequests(3); + const lastRequestId = model.getRequests()[2].id; + + await model.insertSummary( + async () => 'Summary text', + 'beforeLast' + ); + + const requests = model.getRequests(); + // Should have 4 requests: 3 original + 1 summary + expect(requests).to.have.lengthOf(4); + // Summary should be at index 2, original last request at index 3 + expect(requests[2].request.kind).to.equal('summary'); + expect(requests[3].id).to.equal(lastRequestId); + }); + + it('should preserve the trigger request identity (same object)', async () => { + const model = createModelWithRequests(3); + const originalLastRequest = model.getRequests()[2]; + const originalId = originalLastRequest.id; + + await model.insertSummary( + async () => 'Summary text', + 'beforeLast' + ); + + const readdedRequest = model.getRequests()[3]; + // Should be the exact same object + expect(readdedRequest.id).to.equal(originalId); + }); + + it('should mark all requests except trigger as stale', async () => { + const model = createModelWithRequests(3); + const triggerRequestId = model.getRequests()[2].id; + + await model.insertSummary( + async () => 'Summary text', + 'beforeLast' + ); + + const requests = model.getRequests(); + // Requests 1-2 (indices 0-1) should be stale + expect(requests[0].isStale).to.be.true; + expect(requests[1].isStale).to.be.true; + // Summary request (index 2) should not be stale + expect(requests[2].isStale).to.be.false; + // Trigger request (index 3) should not be stale + expect(requests[3].isStale).to.be.false; + expect(requests[3].id).to.equal(triggerRequestId); + }); + }); + + describe('callback failure handling', () => { + it('should rollback on callback returning undefined (end position)', async () => { + const model = createModelWithRequests(3); + const originalRequestCount = model.getRequests().length; + + const result = await model.insertSummary( + async () => undefined, + 'end' + ); + + expect(result).to.be.undefined; + // Model should be unchanged + expect(model.getRequests()).to.have.lengthOf(originalRequestCount); + // Stale flags should be restored + model.getRequests().forEach(r => { + expect(r.isStale).to.be.false; + }); + }); + + it('should rollback on callback throwing error (end position)', async () => { + const model = createModelWithRequests(3); + const originalRequestCount = model.getRequests().length; + + const result = await model.insertSummary( + async () => { throw new Error('Agent failed'); }, + 'end' + ); + + expect(result).to.be.undefined; + // Model should be unchanged + expect(model.getRequests()).to.have.lengthOf(originalRequestCount); + // Stale flags should be restored + model.getRequests().forEach(r => { + expect(r.isStale).to.be.false; + }); + }); + + it('should rollback on callback failure (beforeLast position)', async () => { + const model = createModelWithRequests(3); + const originalRequestIds = model.getRequests().map(r => r.id); + + const result = await model.insertSummary( + async () => undefined, + 'beforeLast' + ); + + expect(result).to.be.undefined; + // Should have same requests in same order + const currentRequestIds = model.getRequests().map(r => r.id); + expect(currentRequestIds).to.deep.equal(originalRequestIds); + // Stale flags should be restored + model.getRequests().forEach(r => { + expect(r.isStale).to.be.false; + }); + }); + + it('should restore trigger request on failure (beforeLast position)', async () => { + const model = createModelWithRequests(3); + const originalLastRequestId = model.getRequests()[2].id; + + const result = await model.insertSummary( + async () => { throw new Error('Agent failed'); }, + 'beforeLast' + ); + + expect(result).to.be.undefined; + // Trigger request should be back in position + const requests = model.getRequests(); + expect(requests).to.have.lengthOf(3); + expect(requests[2].id).to.equal(originalLastRequestId); + }); + }); + + describe('callback receives correct summaryRequest', () => { + it('should pass a valid MutableChatRequestModel to callback', async () => { + const model = createModelWithRequests(2); + let receivedRequest: MutableChatRequestModel | undefined; + + await model.insertSummary( + async summaryRequest => { + receivedRequest = summaryRequest; + return 'Summary'; + }, + 'end' + ); + + expect(receivedRequest).to.not.be.undefined; + expect(receivedRequest!.request.kind).to.equal('summary'); + expect(receivedRequest!.response).to.not.be.undefined; + }); + + it('should allow callback to use summaryRequest for agent invocation', async () => { + const model = createModelWithRequests(2); + let responseModified = false; + + await model.insertSummary( + async summaryRequest => { + // Simulate agent adding content to response + summaryRequest.response.response.addContent( + new TextChatResponseContentImpl('Agent response') as ChatResponseContent + ); + responseModified = true; + return summaryRequest.response.response.asDisplayString(); + }, + 'end' + ); + + expect(responseModified).to.be.true; + }); + }); + + describe('already stale requests', () => { + it('should not re-mark already stale requests', async () => { + const model = createModelWithRequests(4); + // Mark first request as already stale + model.getRequests()[0].isStale = true; + + await model.insertSummary( + async () => 'Summary', + 'end' + ); + + const requests = model.getRequests(); + // First request was already stale, should remain stale + expect(requests[0].isStale).to.be.true; + // Second and third requests should now be stale + expect(requests[1].isStale).to.be.true; + expect(requests[2].isStale).to.be.true; + // Fourth (last before summary) should not be stale + expect(requests[3].isStale).to.be.false; + }); + }); +}); diff --git a/packages/ai-chat/src/common/chat-model-serialization.spec.ts b/packages/ai-chat/src/common/chat-model-serialization.spec.ts index adc3e568d8b37..bb8e41a9965bc 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.spec.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.spec.ts @@ -35,6 +35,71 @@ describe('ChatModel Serialization and Restoration', () => { }; } + describe('isStale property serialization', () => { + it('should not include isStale when false (default)', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + model.addRequest(createParsedRequest('Hello')); + + const serialized = model.toSerializable(); + + expect(serialized.requests[0].isStale).to.be.undefined; + }); + + it('should include isStale when true', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const req = model.addRequest(createParsedRequest('Hello')); + req.isStale = true; + + const serialized = model.toSerializable(); + + expect(serialized.requests[0].isStale).to.be.true; + }); + + it('should restore isStale flag correctly', () => { + const model1 = new MutableChatModel(ChatAgentLocation.Panel); + const req = model1.addRequest(createParsedRequest('Hello')); + req.isStale = true; + + const serialized = model1.toSerializable(); + const model2 = new MutableChatModel(serialized); + + expect(model2.getRequests()[0].isStale).to.be.true; + }); + + it('should default isStale to false when missing in serialized data', () => { + const serializedData = { + sessionId: 'test-session', + location: ChatAgentLocation.Panel, + hierarchy: { + rootBranchId: 'branch-root', + branches: { + 'branch-root': { + id: 'branch-root', + items: [{ requestId: 'request-1' }], + activeBranchIndex: 0 + } + } + }, + requests: [{ + id: 'request-1', + text: 'Hello' + // isStale is intentionally omitted + }], + responses: [{ + id: 'response-1', + requestId: 'request-1', + isComplete: true, + isError: false, + content: [] + }] + }; + + const model = new MutableChatModel(serializedData); + + expect(model.getRequests()[0].isStale).to.be.false; + }); + }); + describe('Simple tree serialization', () => { it('should serialize a chat with a single request', () => { const model = new MutableChatModel(ChatAgentLocation.Panel); diff --git a/packages/ai-chat/src/common/chat-model-serialization.ts b/packages/ai-chat/src/common/chat-model-serialization.ts index 2a9de4c222fcf..86edac1d6685c 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.ts @@ -15,6 +15,7 @@ // ***************************************************************************** import { ChatAgentLocation } from './chat-agents'; +import { ChatRequestKind } from './chat-model'; export interface SerializableChangeSetElement { kind?: string; @@ -41,6 +42,10 @@ export interface SerializableChatRequestData { id: string; text: string; agentId?: string; + /** The type of request: 'user' or 'summary'. Defaults to 'user' if not specified. */ + kind?: ChatRequestKind; + /** Indicates this request has been summarized and should be excluded from prompt construction */ + isStale?: boolean; changeSet?: { title: string; elements: SerializableChangeSetElement[]; @@ -126,6 +131,7 @@ export interface SerializedChatData { title?: string; model: SerializedChatModel; saveDate: number; + lastInputTokens?: number; } export interface SerializableChatsData { diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 4ffd618d46c96..a77fd5b58453d 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -243,6 +243,8 @@ export interface ChangeSetDecoration { readonly additionalInfoSuffixIcon?: string[]; } +export type ChatRequestKind = 'user' | 'summary'; + export interface ChatRequest { readonly text: string; readonly displayText?: string; @@ -254,6 +256,10 @@ export interface ChatRequest { readonly referencedRequestId?: string; readonly variables?: readonly AIVariableResolutionRequest[]; readonly modeId?: string; + /** + * The type of request. Defaults to 'user' if not specified. + */ + readonly kind?: ChatRequestKind; } export interface ChatContext { @@ -269,6 +275,8 @@ export interface ChatRequestModel { readonly context: ChatContext; readonly agentId?: string; readonly data?: { [key: string]: unknown }; + /** Indicates this request has been summarized and should be excluded from prompt construction */ + readonly isStale?: boolean; toSerializable(): SerializableChatRequestData; } @@ -425,6 +433,10 @@ export interface ProgressContentData { message: string; } +export interface SummaryContentData { + content: string; +} + export interface ErrorContentData { message: string; stack?: string; @@ -497,6 +509,17 @@ export interface ProgressChatResponseContent message: string; } +/** + * A summary chat response content represents a condensed summary of previous chat messages. + * It is used when chat history exceeds token limits and older messages need to be summarized. + * The summary is displayed collapsed by default and is read-only. + */ +export interface SummaryChatResponseContent extends ChatResponseContent { + kind: 'summary'; + /** The summary text content */ + content: string; +} + export interface Location { uri: URI; position: Position; @@ -644,6 +667,17 @@ export namespace ProgressChatResponseContent { } } +export namespace SummaryChatResponseContent { + export function is(obj: unknown): obj is SummaryChatResponseContent { + return ( + ChatResponseContent.is(obj) && + obj.kind === 'summary' && + 'content' in obj && + typeof (obj as { content: unknown }).content === 'string' + ); + } +} + export type QuestionResponseHandler = ( selectedOption: { text: string, value?: string }, ) => void; @@ -826,6 +860,12 @@ export class MutableChatModel implements ChatModel, Disposable { for (const reqData of data.requests) { const respData = data.responses.find(r => r.requestId === reqData.id); const requestModel = new MutableChatRequestModel(this, reqData, respData); + // Subscribe to request changes and forward to model emitter (same as addRequest) + requestModel.onDidChange(event => { + if (!ChatChangeEvent.isChangeSetEvent(event)) { + this._onDidChangeEmitter.fire(event); + } + }, this, this.toDispose); requestMap.set(requestModel.id, requestModel); } @@ -901,6 +941,101 @@ export class MutableChatModel implements ChatModel, Disposable { return requestModel; } + /** + * Insert a summary into the model. + * Handles request reordering, stale marking, and summary content creation. + * + * @param summaryCallback - Callback that invokes the summary agent. + * Receives the summary request (already added to model). + * Should invoke the agent and return the summary text, or undefined on failure. + * @param position - 'end' appends summary at end, 'beforeLast' inserts before the last request + * @returns The summary text on success, or undefined on failure + */ + async insertSummary( + summaryCallback: (summaryRequest: MutableChatRequestModel) => Promise, + position: 'end' | 'beforeLast' + ): Promise { + const allRequests = this.getRequests(); + + // Need at least 2 requests to summarize + if (allRequests.length < 2) { + return undefined; + } + + // The request to preserve (most recent exchange, not summarized) + // Captured before any modifications - same for both position modes + const requestToPreserve = allRequests[allRequests.length - 1]; + + let triggerRequest: MutableChatRequestModel | undefined; + let triggerBranch: ChatHierarchyBranch | undefined; + + if (position === 'beforeLast') { + // Remove the last request temporarily - it will be re-added after the summary + triggerRequest = requestToPreserve; + triggerBranch = this.getBranch(triggerRequest.id); + if (triggerBranch) { + triggerBranch.remove(triggerRequest); + } + } + + // Identify which requests will be marked stale after successful summarization + // (all except the preserved one) + const requestsToMarkStale = allRequests.filter(r => !r.isStale && r !== requestToPreserve); + + // Create summary request + // Use the ChatSessionSummaryAgent.ID constant value directly to avoid circular dependency + const summaryRequest = this.addRequest({ + request: { + text: '', + kind: 'summary' + }, + parts: [], + toolRequests: new Map(), + variables: [] + }, 'chat-session-summary-agent'); + + // Call the callback to invoke the agent + // NOTE: Stale marking happens AFTER the callback so the summary agent can see all messages + let summaryText: string | undefined; + try { + summaryText = await summaryCallback(summaryRequest); + } catch (error) { + summaryText = undefined; + } + + if (!summaryText) { + // Rollback: remove summary request, re-add trigger if needed + const summaryBranch = this.getBranch(summaryRequest.id); + if (summaryBranch) { + summaryBranch.remove(summaryRequest); + } + if (position === 'beforeLast' && triggerRequest) { + this._hierarchy.append(triggerRequest as MutableChatRequestModel); + this._onDidChangeEmitter.fire({ kind: 'addRequest', request: triggerRequest }); + } + return undefined; + } + + // Success: mark requests as stale AFTER successful summarization + // This ensures the summary agent could see all messages when building the prompt + for (const request of requestsToMarkStale) { + request.isStale = true; + } + + // Update summary response with SummaryChatResponseContent + summaryRequest.response.response.clearContent(); + const summaryContent = new SummaryChatResponseContentImpl(summaryText); + summaryRequest.response.response.addContent(summaryContent); + + // Re-add trigger request if beforeLast + if (position === 'beforeLast' && triggerRequest) { + this._hierarchy.append(triggerRequest as MutableChatRequestModel); + this._onDidChangeEmitter.fire({ kind: 'addRequest', request: triggerRequest }); + } + + return summaryText; + } + protected getTargetForRequestAddition(request: ParsedChatRequest): (addendum: MutableChatRequestModel) => void { const requestId = request.request.referencedRequestId; const branch = requestId !== undefined && this._hierarchy.findBranch(requestId); @@ -1471,6 +1606,7 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe protected _agentId?: string; protected _data: { [key: string]: unknown }; protected _isEditing = false; + protected _isStale = false; readonly message: ParsedChatRequest; protected readonly toDispose = new DisposableCollection(); @@ -1527,9 +1663,14 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe respData?: SerializableChatResponseData ): void { this._id = reqData.id; - this._request = { text: reqData.text }; + this._request = { + text: reqData.text, + kind: reqData.kind // undefined for old sessions = default 'user' behavior + }; this._agentId = reqData.agentId; this._data = {}; + // Restore isStale flag, defaulting to false for backward compatibility + this._isStale = reqData.isStale ?? false; // Create minimal context this._context = { variables: [] }; @@ -1619,6 +1760,14 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe return this._agentId; } + get isStale(): boolean { + return this._isStale; + } + + set isStale(value: boolean) { + this._isStale = value; + } + cancelEdit(): void { if (this.isEditing) { this._isEditing = false; @@ -1648,15 +1797,22 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe } toSerializable(): SerializableChatRequestData { - return { + const result: SerializableChatRequestData = { id: this.id, text: this.request.text, agentId: this.agentId, + // Only include kind when not default to minimize payload + kind: this.request.kind !== 'user' ? this.request.kind : undefined, changeSet: this._changeSet ? { title: this._changeSet.title, elements: this._changeSet.getElements().map(elem => elem.toSerializable?.()).filter((elem): elem is SerializableChangeSetElement => elem !== undefined) } : undefined }; + // Only include isStale when true to minimize payload + if (this._isStale) { + result.isStale = true; + } + return result; } dispose(): void { @@ -2605,6 +2761,47 @@ export class ProgressChatResponseContentImpl implements ProgressChatResponseCont } } +/** + * Implementation of SummaryChatResponseContent. + * Represents a summary of previous chat messages that exceeded token limits. + * The summary content is included in prompts to maintain conversation context. + */ +export class SummaryChatResponseContentImpl implements SummaryChatResponseContent { + readonly kind = 'summary'; + protected _content: string; + + constructor(content: string) { + this._content = content; + } + + get content(): string { + return this._content; + } + + asString(): string { + return this._content; + } + + asDisplayString(): string | undefined { + return this._content; + } + + toLanguageModelMessage(): TextMessage { + return { + actor: 'ai', + type: 'text', + text: `[Summary of previous conversation]\n${this._content}` + }; + } + + toSerializable(): SerializableChatResponseContentData { + return { + kind: 'summary', + data: { content: this._content } + }; + } +} + /** * Fallback content for unknown content types. * Used when a deserializer is not available (e.g., content from removed extension). diff --git a/packages/ai-chat/src/common/chat-service.ts b/packages/ai-chat/src/common/chat-service.ts index 143d17de20424..c33862981ef0d 100644 --- a/packages/ai-chat/src/common/chat-service.ts +++ b/packages/ai-chat/src/common/chat-service.ts @@ -41,7 +41,7 @@ import { import { ChatRequestParser } from './chat-request-parser'; import { ChatSessionNamingService } from './chat-session-naming-service'; import { ParsedChatRequest, ParsedChatRequestAgentPart } from './parsed-chat-request'; -import { ChatSessionIndex, ChatSessionStore } from './chat-session-store'; +import { ChatModelWithMetadata, ChatSessionIndex, ChatSessionStore } from './chat-session-store'; import { ChatContentDeserializerRegistry } from './chat-content-deserializer'; import { ChangeSetDeserializationContext, ChangeSetElementDeserializerRegistry } from './change-set-element-deserializer'; import { SerializableChangeSetElement, SerializedChatModel } from './chat-model-serialization'; @@ -85,6 +85,7 @@ export function isActiveSessionChangedEvent(obj: unknown): obj is ActiveSessionC export interface SessionCreatedEvent { type: 'created'; sessionId: string; + tokenCount?: number; } export function isSessionCreatedEvent(obj: unknown): obj is SessionCreatedEvent { @@ -450,9 +451,12 @@ export class ChatServiceImpl implements ChatService { } // Store session with title and pinned agent info - this.sessionStore.storeSessions( - { model: session.model, title: session.title, pinnedAgentId: session.pinnedAgent?.id } - ).catch(error => { + const sessionData: ChatModelWithMetadata = { + model: session.model, + title: session.title, + pinnedAgentId: session.pinnedAgent?.id + }; + this.sessionStore.storeSessions(sessionData).catch(error => { this.logger.error('Failed to store chat sessions', error); }); } @@ -513,7 +517,11 @@ export class ChatServiceImpl implements ChatService { }; this._sessions.push(session); this.setupAutoSaveForSession(session); - this.onSessionEventEmitter.fire({ type: 'created', sessionId: session.id }); + this.onSessionEventEmitter.fire({ + type: 'created', + sessionId: session.id, + tokenCount: serialized.lastInputTokens + }); this.logger.debug('Session successfully restored and registered', { sessionId, title: session.title }); diff --git a/packages/ai-chat/src/common/chat-session-store.ts b/packages/ai-chat/src/common/chat-session-store.ts index 719bd6502001b..434c1a7106b69 100644 --- a/packages/ai-chat/src/common/chat-session-store.ts +++ b/packages/ai-chat/src/common/chat-session-store.ts @@ -24,6 +24,7 @@ export interface ChatModelWithMetadata { model: ChatModel; title?: string; pinnedAgentId?: string; + lastInputTokens?: number; } export interface ChatSessionStore { diff --git a/packages/ai-chat/src/common/chat-session-token-tracker.ts b/packages/ai-chat/src/common/chat-session-token-tracker.ts new file mode 100644 index 0000000000000..71d7864a4fb9b --- /dev/null +++ b/packages/ai-chat/src/common/chat-session-token-tracker.ts @@ -0,0 +1,73 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { Event } from '@theia/core'; + +/** + * Event fired when a session's token count is updated. + */ +export interface SessionTokenUpdateEvent { + sessionId: string; + inputTokens: number; +} + +/** + * Event fired when a session's token usage crosses the threshold. + */ +export interface SessionTokenThresholdEvent { + sessionId: string; + inputTokens: number; +} + +export const ChatSessionTokenTracker = Symbol('ChatSessionTokenTracker'); + +/** + * Service that tracks token usage per chat session. + * + * Listens to token usage updates from the backend and correlates them with + * chat sessions via requestId. When a session's input tokens exceed the + * threshold (90% of 200k), it emits an event for summarization. + */ +export interface ChatSessionTokenTracker { + /** + * Event fired when a session's token usage crosses the threshold. + */ + readonly onThresholdExceeded: Event; + + /** + * Event fired when a session's token count is updated. + */ + readonly onSessionTokensUpdated: Event; + + /** + * Get the latest input token count for a session. + * Returns the inputTokens from the most recent request in the session. + */ + getSessionInputTokens(sessionId: string): number | undefined; + + /** + * Reset the session's token count to a new baseline. + * Called after summarization to reflect the reduced token usage. + */ + resetSessionTokens(sessionId: string, newTokenCount: number): void; + + /** + * Reset the triggered state for a session. + * Called after summarization is complete to allow future triggers + * if the session continues to grow. + */ + resetThresholdTrigger(sessionId: string): void; +} diff --git a/packages/ai-chat/src/common/index.ts b/packages/ai-chat/src/common/index.ts index be96415422ba6..7486de135c9a5 100644 --- a/packages/ai-chat/src/common/index.ts +++ b/packages/ai-chat/src/common/index.ts @@ -23,6 +23,7 @@ export * from './chat-model-util'; export * from './chat-request-parser'; export * from './chat-service'; export * from './chat-session-store'; +export * from './chat-session-token-tracker'; export * from './custom-chat-agent'; export * from './parsed-chat-request'; export * from './context-variables'; diff --git a/packages/ai-core/src/common/language-model.ts b/packages/ai-core/src/common/language-model.ts index 1f9b87b3a52ca..e8aa9453dd829 100644 --- a/packages/ai-core/src/common/language-model.ts +++ b/packages/ai-core/src/common/language-model.ts @@ -167,7 +167,20 @@ export interface LanguageModelRequest { tools?: ToolRequest[]; response_format?: { type: 'text' } | { type: 'json_object' } | ResponseFormatJsonSchema; settings?: { [key: string]: unknown }; - clientSettings?: { keepToolCalls: boolean; keepThinking: boolean } + clientSettings?: { keepToolCalls: boolean; keepThinking: boolean }; + /** + * If true, the model should return after the first LLM response + * without executing tool calls. The caller handles tool execution + * and continuation. + * + * Models that don't support this property ignore it and handle + * the tool loop internally (current behavior). This allows gradual + * migration - once all models support it, the old tool loop code + * can be removed. + * + * Default: false (model handles tool loop internally). + */ + singleRoundTrip?: boolean; } export interface ResponseFormatJsonSchema { type: 'json_schema'; @@ -222,6 +235,10 @@ export const isLanguageModelStreamResponsePart = (part: unknown): part is Langua export interface UsageResponsePart { input_tokens: number; output_tokens: number; + /** Input tokens written to cache (Anthropic-specific) */ + cache_creation_input_tokens?: number; + /** Input tokens read from cache (Anthropic-specific) */ + cache_read_input_tokens?: number; } export const isUsageResponsePart = (part: unknown): part is UsageResponsePart => !!(part && typeof part === 'object' && diff --git a/packages/ai-core/src/common/token-usage-service.ts b/packages/ai-core/src/common/token-usage-service.ts index 5c8b7211c63b0..c67252912ec68 100644 --- a/packages/ai-core/src/common/token-usage-service.ts +++ b/packages/ai-core/src/common/token-usage-service.ts @@ -33,6 +33,8 @@ export interface TokenUsage { timestamp: Date; /** Request identifier */ requestId: string; + /** Session identifier */ + sessionId?: string; } export interface TokenUsageParams { @@ -46,6 +48,8 @@ export interface TokenUsageParams { readCachedInputTokens?: number; /** Request identifier */ requestId: string; + /** Session identifier */ + sessionId?: string; } export interface TokenUsageService { diff --git a/packages/ai-core/src/node/token-usage-service-impl.ts b/packages/ai-core/src/node/token-usage-service-impl.ts index 8dad2ac1ae156..c68a17f07549a 100644 --- a/packages/ai-core/src/node/token-usage-service-impl.ts +++ b/packages/ai-core/src/node/token-usage-service-impl.ts @@ -46,7 +46,8 @@ export class TokenUsageServiceImpl implements TokenUsageService { outputTokens: params.outputTokens, model, timestamp: new Date(), - requestId: params.requestId + requestId: params.requestId, + sessionId: params.sessionId }; this.tokenUsages.push(usage); diff --git a/packages/ai-google/src/node/google-language-model.ts b/packages/ai-google/src/node/google-language-model.ts index e8c986d871e35..bfd69f1e3bfb6 100644 --- a/packages/ai-google/src/node/google-language-model.ts +++ b/packages/ai-google/src/node/google-language-model.ts @@ -317,7 +317,8 @@ export class GoogleModel implements LanguageModel { that.tokenUsageService.recordTokenUsage(that.id, { inputTokens: promptTokens, outputTokens: completionTokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId }).catch(error => console.error('Error recording token usage:', error)); } } @@ -452,7 +453,8 @@ export class GoogleModel implements LanguageModel { await this.tokenUsageService.recordTokenUsage(this.id, { inputTokens: promptTokens, outputTokens: completionTokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId }); } } diff --git a/packages/ai-ollama/src/node/ollama-language-model.ts b/packages/ai-ollama/src/node/ollama-language-model.ts index ab17b810ca97f..91ebd23631f51 100644 --- a/packages/ai-ollama/src/node/ollama-language-model.ts +++ b/packages/ai-ollama/src/node/ollama-language-model.ts @@ -71,7 +71,8 @@ export class OllamaModel implements LanguageModel { stream }; const structured = request.response_format?.type === 'json_schema'; - return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken); + const sessionId = 'sessionId' in request ? (request as { sessionId?: string }).sessionId : undefined; + return this.dispatchRequest(ollama, ollamaRequest, structured, sessionId, cancellationToken); } /** @@ -86,7 +87,13 @@ export class OllamaModel implements LanguageModel { }; } - protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise { + protected async dispatchRequest( + ollama: Ollama, + ollamaRequest: ExtendedChatRequest, + structured: boolean, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { // Handle structured output request if (structured) { @@ -95,14 +102,19 @@ export class OllamaModel implements LanguageModel { if (isNonStreaming(ollamaRequest)) { // handle non-streaming request - return this.handleNonStreamingRequest(ollama, ollamaRequest, cancellation); + return this.handleNonStreamingRequest(ollama, ollamaRequest, sessionId, cancellation); } // handle streaming request - return this.handleStreamingRequest(ollama, ollamaRequest, cancellation); + return this.handleStreamingRequest(ollama, ollamaRequest, sessionId, cancellation); } - protected async handleStreamingRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, cancellation?: CancellationToken): Promise { + protected async handleStreamingRequest( + ollama: Ollama, + chatRequest: ExtendedChatRequest, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { const responseStream = await ollama.chat({ ...chatRequest, stream: true, @@ -146,7 +158,7 @@ export class OllamaModel implements LanguageModel { } if (chunk.done) { - that.recordTokenUsage(chunk); + that.recordTokenUsage(chunk, sessionId); if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -169,6 +181,7 @@ export class OllamaModel implements LanguageModel { const continuedResponse = await that.handleStreamingRequest( ollama, chatRequest, + sessionId, cancellation ); @@ -222,7 +235,12 @@ export class OllamaModel implements LanguageModel { } } - protected async handleNonStreamingRequest(ollama: Ollama, chatRequest: ExtendedNonStreamingChatRequest, cancellation?: CancellationToken): Promise { + protected async handleNonStreamingRequest( + ollama: Ollama, + chatRequest: ExtendedNonStreamingChatRequest, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { try { // even though we have a non-streaming request, we still use the streaming version for two reasons: // 1. we can abort the stream if the request is cancelled instead of having to wait for the entire response @@ -251,7 +269,7 @@ export class OllamaModel implements LanguageModel { // if the response is done, record the token usage and check the done reason if (chunk.done) { - this.recordTokenUsage(chunk); + this.recordTokenUsage(chunk, sessionId); lastUpdated = chunk.created_at; if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -273,7 +291,7 @@ export class OllamaModel implements LanguageModel { } // recurse to get the final response content (the intermediate content remains hidden, it is only part of the conversation) - return this.handleNonStreamingRequest(ollama, chatRequest); + return this.handleNonStreamingRequest(ollama, chatRequest, sessionId); } // if no tool calls are necessary, return the final response content @@ -315,12 +333,13 @@ export class OllamaModel implements LanguageModel { return toolCallsForResponse; } - private recordTokenUsage(response: ChatResponse): void { + private recordTokenUsage(response: ChatResponse, sessionId?: string): void { if (this.tokenUsageService && response.prompt_eval_count && response.eval_count) { this.tokenUsageService.recordTokenUsage(this.id, { inputTokens: response.prompt_eval_count, outputTokens: response.eval_count, - requestId: `ollama_${response.created_at}` + requestId: `ollama_${response.created_at}`, + sessionId }).catch(error => console.error('Error recording token usage:', error)); } } diff --git a/packages/ai-openai/src/node/openai-language-model.ts b/packages/ai-openai/src/node/openai-language-model.ts index 408c599446332..a59094ca175dc 100644 --- a/packages/ai-openai/src/node/openai-language-model.ts +++ b/packages/ai-openai/src/node/openai-language-model.ts @@ -161,7 +161,7 @@ export class OpenAiModel implements LanguageModel { }); } - return { stream: new StreamingAsyncIterator(runner, request.requestId, cancellationToken, this.tokenUsageService, this.id) }; + return { stream: new StreamingAsyncIterator(runner, request.requestId, request.sessionId, cancellationToken, this.tokenUsageService, this.id) }; } protected async handleNonStreamingRequest(openai: OpenAI, request: UserRequest): Promise { @@ -181,7 +181,8 @@ export class OpenAiModel implements LanguageModel { { inputTokens: response.usage.prompt_tokens, outputTokens: response.usage.completion_tokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId } ); } @@ -216,7 +217,8 @@ export class OpenAiModel implements LanguageModel { { inputTokens: result.usage.prompt_tokens, outputTokens: result.usage.completion_tokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId } ); } diff --git a/packages/ai-openai/src/node/openai-response-api-utils.ts b/packages/ai-openai/src/node/openai-response-api-utils.ts index 6f22f2a5a302b..e16ca7e778691 100644 --- a/packages/ai-openai/src/node/openai-response-api-utils.ts +++ b/packages/ai-openai/src/node/openai-response-api-utils.ts @@ -95,7 +95,7 @@ export class OpenAiResponseApiUtils { input, ...settings }); - return { stream: this.createSimpleResponseApiStreamIterator(stream, request.requestId, modelId, tokenUsageService, cancellationToken) }; + return { stream: this.createSimpleResponseApiStreamIterator(stream, request.requestId, request.sessionId, modelId, tokenUsageService, cancellationToken) }; } else { const response = await openai.responses.create({ model: model as ResponsesModel, @@ -168,6 +168,7 @@ export class OpenAiResponseApiUtils { protected createSimpleResponseApiStreamIterator( stream: AsyncIterable, requestId: string, + sessionId: string, modelId: string, tokenUsageService?: TokenUsageService, cancellationToken?: CancellationToken @@ -191,7 +192,8 @@ export class OpenAiResponseApiUtils { { inputTokens: event.response.usage.input_tokens, outputTokens: event.response.usage.output_tokens, - requestId + requestId, + sessionId } ); } @@ -752,7 +754,8 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator { }); function createIterator(withCancellationToken = false): StreamingAsyncIterator { - return new StreamingAsyncIterator(mockStream, '', withCancellationToken ? cts.token : undefined); + return new StreamingAsyncIterator(mockStream, '', '', withCancellationToken ? cts.token : undefined); } it('should yield messages in the correct order when consumed immediately', async () => { diff --git a/packages/ai-openai/src/node/openai-streaming-iterator.ts b/packages/ai-openai/src/node/openai-streaming-iterator.ts index 15ca438469cb3..8f555711f9b91 100644 --- a/packages/ai-openai/src/node/openai-streaming-iterator.ts +++ b/packages/ai-openai/src/node/openai-streaming-iterator.ts @@ -32,6 +32,7 @@ export class StreamingAsyncIterator implements AsyncIterableIterator console.error('Error recording token usage:', error)); diff --git a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts index 5711d008ea2e6..94f2e140ba3fa 100644 --- a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts +++ b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts @@ -341,7 +341,8 @@ export class VercelAiModel implements LanguageModel { { inputTokens: result.usage.promptTokens, outputTokens: result.usage.completionTokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId } ); } @@ -371,7 +372,8 @@ export class VercelAiModel implements LanguageModel { this.tokenUsageService?.recordTokenUsage(this.id, { inputTokens: stepResult.usage.promptTokens, outputTokens: stepResult.usage.completionTokens, - requestId: request.requestId + requestId: request.requestId, + sessionId: request.sessionId }); } }, From 2450c9b5951836fbfaddfef5684dea0a9d4dcc17 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Thu, 18 Dec 2025 17:11:06 +0100 Subject: [PATCH 2/5] fix issues --- .../chat-token-usage-indicator.spec.ts | 25 +- .../chat-tree-view/chat-view-tree-widget.tsx | 171 ++-- .../ai-chat-ui/src/browser/style/index.css | 81 +- .../src/browser/ai-chat-frontend-module.ts | 4 - .../chat-language-model-service.spec.ts | 3 +- .../src/browser/chat-session-store-impl.ts | 3 +- ...chat-session-summarization-service.spec.ts | 770 ++++++++++++++++++ .../chat-session-summarization-service.ts | 242 +++++- ...chat-session-token-restore-contribution.ts | 42 - .../chat-session-token-tracker.spec.ts | 203 ++--- .../src/browser/chat-session-token-tracker.ts | 108 +-- packages/ai-chat/src/common/chat-agents.ts | 1 + .../src/common/chat-model-hierarchy.spec.ts | 165 ++++ .../common/chat-model-insert-summary.spec.ts | 185 ++--- .../src/common/chat-model-serialization.ts | 6 +- packages/ai-chat/src/common/chat-model.ts | 97 +-- packages/ai-chat/src/common/chat-service.ts | 7 +- .../ai-chat/src/common/chat-session-store.ts | 1 + .../src/common/chat-session-summary-agent.ts | 1 + .../src/common/chat-session-token-tracker.ts | 64 +- 20 files changed, 1534 insertions(+), 645 deletions(-) create mode 100644 packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts delete mode 100644 packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts create mode 100644 packages/ai-chat/src/common/chat-model-hierarchy.spec.ts diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts index 4a47251b145e6..fbfc7705130c0 100644 --- a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts @@ -26,7 +26,6 @@ import { flushSync } from '@theia/core/shared/react-dom'; import { Emitter } from '@theia/core'; import { ChatSessionTokenTracker, - SessionTokenThresholdEvent, SessionTokenUpdateEvent, CHAT_TOKEN_THRESHOLD } from '@theia/ai-chat/lib/browser'; @@ -39,14 +38,16 @@ describe('ChatTokenUsageIndicator', () => { let root: ReactDOMClient.Root; const createMockTokenTracker = (tokens: number | undefined): ChatSessionTokenTracker => { - const thresholdEmitter = new Emitter(); const updateEmitter = new Emitter(); return { - onThresholdExceeded: thresholdEmitter.event, onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => tokens, resetSessionTokens: () => { }, - resetThresholdTrigger: () => { } + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } }; }; @@ -240,15 +241,17 @@ describe('ChatTokenUsageIndicator', () => { describe('subscription to token updates', () => { it('should update when token tracker fires update event', () => { const updateEmitter = new Emitter(); - const thresholdEmitter = new Emitter(); let currentTokens = 50000; const mockTracker: ChatSessionTokenTracker = { - onThresholdExceeded: thresholdEmitter.event, onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => currentTokens, resetSessionTokens: () => { }, - resetThresholdTrigger: () => { } + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } }; renderComponent({ @@ -273,14 +276,16 @@ describe('ChatTokenUsageIndicator', () => { it('should not update when event is for different session', () => { const updateEmitter = new Emitter(); - const thresholdEmitter = new Emitter(); const mockTracker: ChatSessionTokenTracker = { - onThresholdExceeded: thresholdEmitter.event, onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => 50000, resetSessionTokens: () => { }, - resetThresholdTrigger: () => { } + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } }; renderComponent({ diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index 87173c728a928..9e101a369756e 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -394,10 +394,15 @@ export class ChatViewTreeWidget extends TreeWidget { return { parent: this.model.root as CompositeTreeNode, get id(): string { - return this.request.id; + return this.request?.id ?? `empty-branch-${branch.id}`; }, get request(): ChatRequestModel { - return branch.get(); + // Guard against empty branches - can happen during insertSummary + try { + return branch.get(); + } catch { + return undefined as unknown as ChatRequestModel; + } }, branch, sessionId: this.chatModelId @@ -485,9 +490,13 @@ export class ChatViewTreeWidget extends TreeWidget { const nodes: TreeNode[] = []; this.chatModelId = chatModel.id; chatModel.getBranches().forEach(branch => { + // Skip empty branches (can occur during insertSummary operations) + if (branch.items.length === 0) { + return; + } const request = branch.get(); nodes.push(this.mapRequestToNode(branch)); - // Skip ResponseNode for summary nodes - they render content in the RequestNode + // Skip separate response node for summary requests - response is rendered within request node if (request.request.kind !== 'summary') { nodes.push(this.mapResponseToNode(request.response)); } @@ -504,20 +513,16 @@ export class ChatViewTreeWidget extends TreeWidget { if (!TreeNode.isVisible(node)) { return undefined; } + if (isRequestNode(node)) { + // Skip rendering if the branch is empty (request will be undefined) + if (!node.request) { + return undefined; + } + } if (!(isRequestNode(node) || isResponseNode(node))) { return super.renderNode(node, props); } - // Summary nodes render without agent header - const isSummaryNode = isRequestNode(node) && node.request.request.kind === 'summary'; - if (isSummaryNode) { - return -

this.handleContextMenu(node, e)}> - {this.renderDetail(node)} -
- ; - } - return
this.handleContextMenu(node, e)}> {this.renderAgent(node)} @@ -636,6 +641,7 @@ export class ChatViewTreeWidget extends TreeWidget { chatAgentService={this.chatAgentService} variableService={this.variableService} openerService={this.openerService} + renderResponseContent={(content: ChatResponseContent) => this.renderResponseContent(content)} provideChatInputWidget={() => { const editableNode = node; if (isEditableRequestNode(editableNode)) { @@ -666,6 +672,21 @@ export class ChatViewTreeWidget extends TreeWidget { />; } + protected renderResponseContent(content: ChatResponseContent): React.ReactNode { + const renderer = this.chatResponsePartRenderers.getContributions().reduce<[number, ChatResponsePartRenderer | undefined]>( + (prev, current) => { + const prio = current.canHandle(content); + if (prio > prev[0]) { + return [prio, current]; + } return prev; + }, + [-1, undefined])[1]; + if (!renderer) { + return undefined; + } + return renderer.render(content, undefined as unknown as ResponseNode); + } + protected renderChatResponse(node: ResponseNode): React.ReactNode { return (
@@ -768,15 +789,10 @@ const WidgetContainer: React.FC = ({ widget }) => { return
; }; -const SummaryContentRenderer: React.FC<{ content: string; openerService: OpenerService }> = ({ content, openerService }) => { - const ref = useMarkdownRendering(content, openerService); - return
; -}; - const ChatRequestRender = ( { node, hoverService, chatAgentService, variableService, openerService, - provideChatInputWidget + provideChatInputWidget, renderResponseContent }: { node: RequestNode, hoverService: HoverService, @@ -784,32 +800,15 @@ const ChatRequestRender = ( variableService: AIVariableService, openerService: OpenerService, provideChatInputWidget: () => ReactWidget | undefined, + renderResponseContent?: (content: ChatResponseContent) => React.ReactNode, }) => { - const parts = node.request.message.parts; - const isStale = node.request.isStale === true; - const isSummaryNode = node.request.request.kind === 'summary'; + // Capture the request object once to avoid getter issues + const request = node.request; + const parts = request.message.parts; + const isStale = request.isStale === true; + const isSummary = request.request.kind === 'summary'; - // Summary nodes render header and content in a single unified node - if (isSummaryNode) { - const summaryContent = node.request.response.response.asDisplayString(); - return ( -
-
- - - {nls.localize('theia/ai/chat-ui/chat-view-tree-widget/conversationSummary', 'Conversation Summary')} - - {summaryContent && ( -
- -
- )} -
-
- ); - } - - if (EditableChatRequestModel.isEditing(node.request)) { + if (EditableChatRequestModel.isEditing(request)) { const widget = provideChatInputWidget(); if (widget) { return
@@ -847,43 +846,57 @@ const ChatRequestRender = ( }; return ( -
-

- {parts.map((part, index) => { - if (part instanceof ParsedChatRequestAgentPart || part instanceof ParsedChatRequestVariablePart) { - let description = undefined; - let className = ''; - if (part instanceof ParsedChatRequestAgentPart) { - description = chatAgentService.getAgent(part.agentId)?.description; - className = 'theia-RequestNode-AgentLabel'; - } else if (part instanceof ParsedChatRequestVariablePart) { - description = variableService.getVariable(part.variableName)?.description; - className = 'theia-RequestNode-VariableLabel'; +

+ {isSummary && ( +
+ + {nls.localize('theia/ai-chat/summary', 'Conversation Summary')} +
+ )} + {isSummary && renderResponseContent ? ( +
+ {request.response.response.content.map((c, i) => ( +
{renderResponseContent(c)}
+ ))} +
+ ) : ( +

+ {parts.map((part, index) => { + if (part instanceof ParsedChatRequestAgentPart || part instanceof ParsedChatRequestVariablePart) { + let description = undefined; + let className = ''; + if (part instanceof ParsedChatRequestAgentPart) { + description = chatAgentService.getAgent(part.agentId)?.description; + className = 'theia-RequestNode-AgentLabel'; + } else if (part instanceof ParsedChatRequestVariablePart) { + description = variableService.getVariable(part.variableName)?.description; + className = 'theia-RequestNode-VariableLabel'; + } + return ( + + ); + } else { + const ref = useMarkdownRendering( + part.text + .replace(/^[\r\n]+|[\r\n]+$/g, '') // remove excessive new lines + .replace(/(^ )/g, ' '), // enforce keeping space before + openerService, + true + ); + return ( + + ); } - return ( - - ); - } else { - const ref = useMarkdownRendering( - part.text - .replace(/^[\r\n]+|[\r\n]+$/g, '') // remove excessive new lines - .replace(/(^ )/g, ' '), // enforce keeping space before - openerService, - true - ); - return ( - - ); - } - })} -

- {renderFooter()} + })} +

+ )} + {!isSummary && renderFooter()}
); }; diff --git a/packages/ai-chat-ui/src/browser/style/index.css b/packages/ai-chat-ui/src/browser/style/index.css index 5bcec8d7ed502..d1d5ec7605fc9 100644 --- a/packages/ai-chat-ui/src/browser/style/index.css +++ b/packages/ai-chat-ui/src/browser/style/index.css @@ -1283,93 +1283,32 @@ details[open].collapsible-arguments .collapsible-arguments-summary { opacity: 0.7; } - -/* Summary node styles */ -.theia-ChatNode-Summary { - background-color: var(--theia-editorGroupHeader-tabsBackground); - padding: 8px 16px; -} - +/* Summary request styles */ .theia-RequestNode-summary { - padding: 0; - width: 100%; -} - -.theia-RequestNode-summary details { - width: 100%; - border: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-editor-inactiveSelectionBackground); + border-left: 3px solid var(--theia-focusBorder); + padding-left: 8px; + margin: 8px 0; border-radius: 4px; - background-color: var(--theia-sideBar-background); } .theia-RequestNode-SummaryHeader { - cursor: pointer; - padding: 8px 12px; display: flex; align-items: center; - gap: 8px; - font-weight: 600; - color: var(--theia-foreground); - user-select: none; - list-style: none; - position: relative; -} - -.theia-RequestNode-SummaryHeader::-webkit-details-marker { - display: none; -} - -.theia-RequestNode-SummaryHeader::marker { - content: ""; -} - -.theia-RequestNode-SummaryHeader:hover { - background-color: var(--theia-toolbar-hoverBackground); + gap: 6px; + font-weight: 500; + margin-bottom: 8px; + color: var(--theia-descriptionForeground); } .theia-RequestNode-SummaryHeader .codicon { - color: var(--theia-button-background); font-size: 14px; } -.theia-RequestNode-SummaryHeader::after { - content: "\25BC"; - position: absolute; - right: 12px; - top: 50%; - transform: translateY(-50%); - transition: transform 0.2s ease; - color: var(--theia-descriptionForeground); - font-size: var(--theia-ui-font-size1); -} - -.theia-RequestNode-summary details:not([open]) .theia-RequestNode-SummaryHeader::after { - transform: translateY(-50%) rotate(-90deg); -} - .theia-RequestNode-SummaryContent { - padding: 12px; - border-top: 1px solid var(--theia-sideBarSectionHeader-border); - background-color: var(--theia-editor-background); - border-radius: 0 0 4px 4px; + margin-top: 8px; } -.theia-RequestNode-SummaryIndicator { - display: flex; - align-items: center; - gap: 6px; - padding: 4px 8px; - background-color: var(--theia-editorGroupHeader-tabsBackground); - color: var(--theia-descriptionForeground); - border-radius: 4px; - font-size: 12px; - font-weight: 500; -} - -.theia-RequestNode-SummaryIndicator .codicon { - font-size: 14px; - color: var(--theia-button-background); -} /* Chat summary styles */ .theia-chat-summary { diff --git a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts index 55939a898d0bb..96fcdbb1aa3f5 100644 --- a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts +++ b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts @@ -74,7 +74,6 @@ import { } from '../common/change-set-element-deserializer'; import { ChangeSetFileElementDeserializerContribution } from './change-set-file-element-deserializer'; import { ChatSessionTokenTracker, ChatSessionTokenTrackerImpl } from './chat-session-token-tracker'; -import { ChatSessionTokenRestoreContribution } from './chat-session-token-restore-contribution'; import { ChatSessionSummarizationService, ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; import { ChatLanguageModelServiceImpl } from './chat-language-model-service'; @@ -194,9 +193,6 @@ export default new ContainerModule((bind, unbind, isBound, rebind) => { bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); bind(ChatSessionTokenTracker).toService(ChatSessionTokenTrackerImpl); - bind(ChatSessionTokenRestoreContribution).toSelf().inSingletonScope(); - bind(FrontendApplicationContribution).toService(ChatSessionTokenRestoreContribution); - bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); bind(ChatSessionSummarizationService).toService(ChatSessionSummarizationServiceImpl); bind(FrontendApplicationContribution).toService(ChatSessionSummarizationServiceImpl); diff --git a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts index 614c0608058f8..5c9a138510552 100644 --- a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts @@ -56,8 +56,7 @@ describe('ChatLanguageModelServiceImpl', () => { } as unknown as sinon.SinonStubbedInstance; mockTokenTracker = { - getSessionInputTokens: sinon.stub(), - onThresholdExceeded: sinon.stub() + getSessionInputTokens: sinon.stub() } as unknown as sinon.SinonStubbedInstance; mockSummarizationService = { diff --git a/packages/ai-chat/src/browser/chat-session-store-impl.ts b/packages/ai-chat/src/browser/chat-session-store-impl.ts index 795c264674033..8ea2ae984eb20 100644 --- a/packages/ai-chat/src/browser/chat-session-store-impl.ts +++ b/packages/ai-chat/src/browser/chat-session-store-impl.ts @@ -80,7 +80,8 @@ export class ChatSessionStoreImpl implements ChatSessionStore { pinnedAgentId: session.pinnedAgentId, saveDate: session.saveDate, model: modelData, - lastInputTokens: this.tokenTracker.getSessionInputTokens(session.model.id) + lastInputTokens: this.tokenTracker.getSessionInputTokens(session.model.id), + branchTokens: this.tokenTracker.getBranchTokensForSession(session.model.id) }; this.logger.debug('Writing session to file', { sessionId: session.model.id, diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts new file mode 100644 index 0000000000000..334173d6b2939 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts @@ -0,0 +1,770 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { Container } from '@theia/core/shared/inversify'; +import { Emitter, ILogger } from '@theia/core'; +import { TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; +import { ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; +import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatRequestInvocation, ChatService, SessionCreatedEvent, SessionDeletedEvent } from '../common/chat-service'; +import { ChatRequestModel, ChatResponseModel, ChatSession } from '../common'; +import { ChatSessionStore } from '../common/chat-session-store'; + +describe('ChatSessionSummarizationServiceImpl', () => { + let container: Container; + let service: ChatSessionSummarizationServiceImpl; + let tokenTracker: sinon.SinonStubbedInstance; + let chatService: sinon.SinonStubbedInstance; + let tokenUsageClient: sinon.SinonStubbedInstance; + let logger: sinon.SinonStubbedInstance; + + let tokenUsageEmitter: Emitter; + let sessionEventEmitter: Emitter; + let sessionRegistry: Map; + let sessionStore: sinon.SinonStubbedInstance; + + // Helper to create a mock TokenUsage event + function createTokenUsage(params: { + sessionId: string; + requestId: string; + inputTokens: number; + outputTokens: number; + cachedInputTokens?: number; + readCachedInputTokens?: number; + }): TokenUsage { + return { + ...params, + model: 'test-model', + timestamp: new Date() + }; + } + + // Helper to create a mock session + function createMockSession(sessionId: string, activeBranchId: string, branches: { id: string }[] = []): ChatSession { + const modelChangeEmitter = new Emitter(); + const allBranches = branches.length > 0 ? branches : [{ id: activeBranchId }]; + return { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub().callsFake((requestId: string) => { + // Return branch based on requestId pattern: 'request-for-branchX' => { id: 'branchX' } + const match = requestId.match(/request-for-(.+)/); + if (match) { + return { id: match[1] }; + } + return undefined; + }), + getBranches: sinon.stub().returns(allBranches), + getRequest: sinon.stub().callsFake((requestId: string) => { + if (requestId.includes('summary')) { + return { request: { kind: 'summary' } }; + } + return { request: { kind: 'user' } }; + }), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + } + + beforeEach(() => { + container = new Container(); + + // Create emitters for event simulation + tokenUsageEmitter = new Emitter(); + sessionEventEmitter = new Emitter(); + + // Create session registry for dynamic lookup + sessionRegistry = new Map(); + + // Create stubs + const branchTokensMap = new Map(); + tokenTracker = { + resetSessionTokens: sinon.stub(), + getSessionInputTokens: sinon.stub(), + onSessionTokensUpdated: sinon.stub(), + setBranchTokens: sinon.stub().callsFake((sessionId: string, branchId: string, tokens: number) => { + branchTokensMap.set(`${sessionId}:${branchId}`, tokens); + }), + getBranchTokens: sinon.stub().callsFake((sessionId: string, branchId: string) => branchTokensMap.get(`${sessionId}:${branchId}`)), + getBranchTokensForSession: sinon.stub().callsFake((sessionId: string) => { + const result: { [branchId: string]: number } = {}; + const prefix = `${sessionId}:`; + for (const [key, value] of branchTokensMap.entries()) { + if (key.startsWith(prefix)) { + const branchId = key.substring(prefix.length); + result[branchId] = value; + } + } + return result; + }), + restoreBranchTokens: sinon.stub().callsFake((sessionId: string, branchTokens: { [branchId: string]: number }) => { + for (const [branchId, tokens] of Object.entries(branchTokens)) { + branchTokensMap.set(`${sessionId}:${branchId}`, tokens); + } + }), + clearSessionBranchTokens: sinon.stub().callsFake((sessionId: string) => { + const prefix = `${sessionId}:`; + for (const key of branchTokensMap.keys()) { + if (key.startsWith(prefix)) { + branchTokensMap.delete(key); + } + } + }) + } as unknown as sinon.SinonStubbedInstance; + + chatService = { + getSession: sinon.stub().callsFake((sessionId: string) => sessionRegistry.get(sessionId)), + getSessions: sinon.stub().returns([]), + onSessionEvent: sessionEventEmitter.event, + sendRequest: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + tokenUsageClient = { + onTokenUsageUpdated: tokenUsageEmitter.event + } as unknown as sinon.SinonStubbedInstance; + + logger = { + info: sinon.stub(), + warn: sinon.stub(), + error: sinon.stub(), + debug: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + sessionStore = { + storeSessions: sinon.stub().resolves(), + readSession: sinon.stub().resolves(undefined), + deleteSession: sinon.stub().resolves(), + clearAllSessions: sinon.stub().resolves(), + getSessionIndex: sinon.stub().resolves({}) + } as unknown as sinon.SinonStubbedInstance; + + // Bind to container + container.bind(ChatSessionTokenTracker).toConstantValue(tokenTracker); + container.bind(ChatService).toConstantValue(chatService); + container.bind(TokenUsageServiceClient).toConstantValue(tokenUsageClient); + container.bind(ILogger).toConstantValue(logger); + container.bind(ChatSessionStore).toConstantValue(sessionStore); + container.bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); + + service = container.get(ChatSessionSummarizationServiceImpl); + // Manually call init since @postConstruct won't run in tests + (service as unknown as { init: () => void }).init(); + }); + + afterEach(() => { + sinon.restore(); + tokenUsageEmitter.dispose(); + sessionEventEmitter.dispose(); + sessionRegistry.clear(); + }); + + // Helper to create a mock ChatRequestInvocation + function createMockInvocation(params: { + requestId: string; + isError: boolean; + displayString: string; + errorObject?: Error; + }): ChatRequestInvocation { + const mockRequest = { + id: params.requestId, + request: { kind: 'summary' as const }, + addData: sinon.stub() + } as unknown as ChatRequestModel; + + const mockResponse = { + isError: params.isError, + errorObject: params.errorObject, + response: { + asDisplayString: () => params.displayString + } + } as unknown as ChatResponseModel; + + return { + requestCompleted: Promise.resolve(mockRequest), + responseCreated: Promise.resolve(mockResponse), + responseCompleted: Promise.resolve(mockResponse) + }; + } + + describe('performSummarization error handling', () => { + it('should return undefined and log warning when sendRequest response has error', async () => { + const sessionId = 'session-with-error'; + const branchId = 'branch-A'; + + // Create a mock model with insertSummary that calls the callback + const modelChangeEmitter = new Emitter(); + const mockModel = { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchId }]), + getRequest: sinon.stub(), + getRequests: sinon.stub().returns([]), + onDidChange: modelChangeEmitter.event, + insertSummary: sinon.stub().callsFake( + async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { + const callbackResult = await callback(); + return callbackResult?.summaryText; + } + ) + }; + + const session = { + id: sessionId, + isActive: true, + model: mockModel + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Mock sendRequest to return an invocation with error response + (chatService.sendRequest as sinon.SinonStub).resolves( + createMockInvocation({ + requestId: 'summary-request-id', + isError: true, + displayString: '', + errorObject: new Error('No language model configured') + }) + ); + + // Call triggerSummarization + const result = await service.triggerSummarization(sessionId, false); + + // Verify result is undefined (error response returns undefined from callback) + expect(result).to.be.undefined; + + // Verify warning was logged for failed summarization + expect((logger.warn as sinon.SinonStub).called).to.be.true; + }); + + it('should return undefined and log warning when sendRequest returns empty response', async () => { + const sessionId = 'session-with-empty'; + const branchId = 'branch-A'; + + // Create a mock model with insertSummary that calls the callback + const modelChangeEmitter = new Emitter(); + const mockModel = { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchId }]), + getRequest: sinon.stub(), + getRequests: sinon.stub().returns([]), + onDidChange: modelChangeEmitter.event, + insertSummary: sinon.stub().callsFake( + async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { + const callbackResult = await callback(); + return callbackResult?.summaryText; + } + ) + }; + + const session = { + id: sessionId, + isActive: true, + model: mockModel + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Mock sendRequest to return an invocation with empty response + (chatService.sendRequest as sinon.SinonStub).resolves( + createMockInvocation({ + requestId: 'summary-request-id', + isError: false, + displayString: ' ' + }) + ); + + // Call triggerSummarization + const result = await service.triggerSummarization(sessionId, false); + + // Verify result is undefined (empty response returns undefined from callback) + expect(result).to.be.undefined; + + // Verify warning was logged + expect((logger.warn as sinon.SinonStub).called).to.be.true; + }); + + it('should return summary text when response is successful', async () => { + const sessionId = 'session-success'; + const branchId = 'branch-A'; + const summaryText = 'This is a valid summary of the conversation.'; + + // Create a mock model with insertSummary that calls the callback + const modelChangeEmitter = new Emitter(); + const mockModel = { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchId }]), + getRequest: sinon.stub(), + getRequests: sinon.stub().returns([{ request: { kind: 'summary' }, getDataByKey: sinon.stub() }]), + onDidChange: modelChangeEmitter.event, + insertSummary: sinon.stub().callsFake( + async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { + const callbackResult = await callback(); + return callbackResult?.summaryText; + } + ) + }; + + const session = { + id: sessionId, + isActive: true, + model: mockModel + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Mock sendRequest to return a successful invocation + (chatService.sendRequest as sinon.SinonStub).resolves( + createMockInvocation({ + requestId: 'summary-request-id', + isError: false, + displayString: summaryText + }) + ); + + // Call triggerSummarization + const result = await service.triggerSummarization(sessionId, false); + + // Verify result is the summary text + expect(result).to.equal(summaryText); + }); + + it('should reset token count to output tokens after successful summarization', async () => { + const sessionId = 'session-with-output-tokens'; + const branchId = 'branch-A'; + const summaryText = 'This is a valid summary.'; + const outputTokens = 1500; + + // Create a mock model with insertSummary that calls the callback + const modelChangeEmitter = new Emitter(); + const summaryRequestMock = { + request: { kind: 'summary' }, + getDataByKey: sinon.stub().withArgs('capturedOutputTokens').returns(outputTokens) + }; + const mockModel = { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchId }]), + getRequest: sinon.stub(), + getRequests: sinon.stub().returns([summaryRequestMock]), + onDidChange: modelChangeEmitter.event, + insertSummary: sinon.stub().callsFake( + async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { + const callbackResult = await callback(); + return callbackResult?.summaryText; + } + ) + }; + + const session = { + id: sessionId, + isActive: true, + model: mockModel + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Mock sendRequest to return a successful invocation + (chatService.sendRequest as sinon.SinonStub).resolves( + createMockInvocation({ + requestId: 'summary-request-id', + isError: false, + displayString: summaryText + }) + ); + + // Call triggerSummarization + await service.triggerSummarization(sessionId, false); + + // Verify token tracker was reset to output tokens (not 0) + expect((tokenTracker.resetSessionTokens as sinon.SinonStub).calledWith(sessionId, outputTokens)).to.be.true; + expect((tokenTracker.setBranchTokens as sinon.SinonStub).calledWith(sessionId, branchId, outputTokens)).to.be.true; + }); + }); + + describe('per-branch token tracking', () => { + it('should attribute tokens to the correct branch via model.getBranch(requestId)', () => { + const sessionId = 'session-1'; + const branchId = 'branch-A'; + const requestId = `request-for-${branchId}`; + const session = createMockSession(sessionId, branchId); + + sessionRegistry.set(sessionId, session); + + // Fire token usage event + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId, + inputTokens: 1000, + outputTokens: 100 + })); + + // Verify branchTokens map is updated + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens[branchId]).to.equal(1000); + }); + + it('should update branchTokens when token usage event is for active branch', () => { + const sessionId = 'session-2'; + const activeBranchId = 'branch-active'; + const requestId = `request-for-${activeBranchId}`; + const session = createMockSession(sessionId, activeBranchId); + + sessionRegistry.set(sessionId, session); + + // Fire token usage event for active branch + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId, + inputTokens: 5000, + outputTokens: 200 + })); + + // Verify branchTokens was updated (which confirms the handler ran and processed active branch) + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens[activeBranchId]).to.equal(5000); + }); + + it('should NOT trigger tracker reset for non-active branch but should store tokens', () => { + const sessionId = 'session-3'; + const activeBranchId = 'branch-B'; + const nonActiveBranchId = 'branch-A'; + const requestId = `request-for-${nonActiveBranchId}`; + // Active branch is B, but we fire event for branch A + const session = createMockSession(sessionId, activeBranchId, [ + { id: nonActiveBranchId }, + { id: activeBranchId } // Last element is active + ]); + + sessionRegistry.set(sessionId, session); + + const callCountBefore = tokenTracker.resetSessionTokens.callCount; + + // Fire token usage event for non-active branch + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId, + inputTokens: 3000, + outputTokens: 150 + })); + + // Verify tokenTracker.resetSessionTokens was NOT called additionally + expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); + + // But branchTokens should be updated + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens[nonActiveBranchId]).to.equal(3000); + }); + + it('should restore stored tokens when branch changes', () => { + const sessionId = 'session-4'; + const branchA = 'branch-A'; + const branchB = 'branch-B'; + + // Pre-populate branch tokens via tracker + tokenTracker.setBranchTokens(sessionId, branchA, 2000); + tokenTracker.setBranchTokens(sessionId, branchB, 4000); + + // Create session and fire created event to set up branch change listener + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchB }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event to set up listener + sessionEventEmitter.fire({ type: 'created', sessionId }); + + // Simulate branch change to branch A + modelChangeEmitter.fire({ + kind: 'changeHierarchyBranch', + branch: { id: branchA } + }); + + // Verify tokenTracker.resetSessionTokens was called with branch A's tokens + expect(tokenTracker.resetSessionTokens.calledWith(sessionId, 2000)).to.be.true; + }); + + it('should emit undefined when switching to branch with no stored tokens', () => { + const sessionId = 'session-5'; + const unknownBranchId = 'branch-unknown'; + + // Create session without pre-populating tokens for the unknown branch + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: 'branch-other' }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event to set up listener + sessionEventEmitter.fire({ type: 'created', sessionId }); + + // Simulate branch change to unknown branch + modelChangeEmitter.fire({ + kind: 'changeHierarchyBranch', + branch: { id: unknownBranchId } + }); + + // Verify tokenTracker.resetSessionTokens was called with undefined + expect(tokenTracker.resetSessionTokens.calledWith(sessionId, undefined)).to.be.true; + }); + + it('should only trigger threshold for active branch', async () => { + const sessionId = 'session-6'; + const activeBranchId = 'branch-active'; + const nonActiveBranchId = 'branch-other'; + + // Create session with two branches, active is the last one + const session = createMockSession(sessionId, activeBranchId, [ + { id: nonActiveBranchId }, + { id: activeBranchId } + ]); + + sessionRegistry.set(sessionId, session); + + // Spy on handleThresholdExceeded + const handleThresholdSpy = sinon.spy( + service as unknown as { handleThresholdExceeded: (event: { sessionId: string; inputTokens: number }) => Promise }, + 'handleThresholdExceeded' + ); + + // Fire token usage event exceeding threshold for NON-active branch + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId: `request-for-${nonActiveBranchId}`, + inputTokens: CHAT_TOKEN_THRESHOLD + 10000, + outputTokens: 100 + })); + + // handleThresholdExceeded should NOT be called for non-active branch + expect(handleThresholdSpy.called).to.be.false; + + // Now fire for active branch + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId: `request-for-${activeBranchId}`, + inputTokens: CHAT_TOKEN_THRESHOLD + 10000, + outputTokens: 100 + })); + + // handleThresholdExceeded SHOULD be called for active branch + expect(handleThresholdSpy.calledOnce).to.be.true; + expect(handleThresholdSpy.calledWith({ + sessionId, + inputTokens: CHAT_TOKEN_THRESHOLD + 10000 + })).to.be.true; + }); + + it('should remove all branch entries when session is deleted', () => { + const sessionId = 'session-to-delete'; + + // Pre-populate branch tokens via tracker and triggeredBranches + tokenTracker.setBranchTokens(sessionId, 'branch-A', 1000); + tokenTracker.setBranchTokens(sessionId, 'branch-B', 2000); + tokenTracker.setBranchTokens('other-session', 'branch-X', 5000); + + const triggeredBranchesSet = (service as unknown as { triggeredBranches: Set }).triggeredBranches; + triggeredBranchesSet.add(`${sessionId}:branch-A`); + triggeredBranchesSet.add(`${sessionId}:branch-B`); + triggeredBranchesSet.add('other-session:branch-X'); + + // Fire session deleted event + sessionEventEmitter.fire({ type: 'deleted', sessionId }); + + // Verify clearSessionBranchTokens was called + expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; + + // Verify triggeredBranches entries for deleted session are removed + expect(triggeredBranchesSet.has(`${sessionId}:branch-A`)).to.be.false; + expect(triggeredBranchesSet.has(`${sessionId}:branch-B`)).to.be.false; + + // Verify other session's triggeredBranches entries are preserved + expect(triggeredBranchesSet.has('other-session:branch-X')).to.be.true; + }); + + it('should populate branchTokens on persistence restore', () => { + const sessionId = 'restored-session'; + const activeBranchId = 'branch-restored'; + + // Create session + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: activeBranchId }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event with branchTokens data + const branchTokensData = { + 'branch-restored': 8000, + 'branch-other': 3000 + }; + sessionEventEmitter.fire({ + type: 'created', + sessionId, + branchTokens: branchTokensData + }); + + // Verify restoreBranchTokens was called with correct data + expect((tokenTracker.restoreBranchTokens as sinon.SinonStub).calledWith(sessionId, branchTokensData)).to.be.true; + + // Verify getBranchTokensForSession returns restored data + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens).to.deep.equal(branchTokensData); + }); + + it('should skip summary requests in token handler', () => { + const sessionId = 'session-7'; + const branchId = 'branch-A'; + const summaryRequestId = 'summary-request-for-branch-A'; + + // Create session where getRequest returns summary kind for specific request + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub().callsFake((requestId: string) => { + if (requestId === summaryRequestId) { + return { id: branchId }; + } + return undefined; + }), + getBranches: sinon.stub().returns([{ id: branchId }]), + getRequest: sinon.stub().callsFake((requestId: string) => { + if (requestId === summaryRequestId) { + return { request: { kind: 'summary' } }; + } + return { request: { kind: 'user' } }; + }), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + const callCountBefore = tokenTracker.resetSessionTokens.callCount; + + // Fire token usage event for summary request + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId: summaryRequestId, + inputTokens: 5000, + outputTokens: 200 + })); + + // Verify branchTokens was NOT updated + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens[branchId]).to.be.undefined; + + // Verify tokenTracker.resetSessionTokens was NOT called additionally + expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); + }); + + it('should handle cached input tokens correctly', () => { + const sessionId = 'session-8'; + const branchId = 'branch-A'; + const requestId = `request-for-${branchId}`; + const session = createMockSession(sessionId, branchId); + + sessionRegistry.set(sessionId, session); + + // Fire token usage event with cached tokens + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId, + inputTokens: 1000, + cachedInputTokens: 500, + readCachedInputTokens: 200, + outputTokens: 100 + })); + + // Verify branchTokens includes all input token types + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens[branchId]).to.equal(1700); // 1000 + 500 + 200 + }); + + it('should not update branchTokens when session is not found', () => { + const sessionId = 'non-existent-session'; + + // Don't add to sessionRegistry - session not found + + // Fire token usage event + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId: 'some-request', + inputTokens: 1000, + outputTokens: 100 + })); + + // Verify branchTokens was NOT updated + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens).to.deep.equal({}); + }); + + it('should not update branchTokens when branch is not found for request', () => { + const sessionId = 'session-9'; + const branchId = 'branch-A'; + const unknownRequestId = 'unknown-request'; + + // Create session where getBranch returns undefined for unknown request + const session = createMockSession(sessionId, branchId); + ((session.model as unknown as { getBranch: sinon.SinonStub }).getBranch).withArgs(unknownRequestId).returns(undefined); + + sessionRegistry.set(sessionId, session); + + const callCountBefore = tokenTracker.resetSessionTokens.callCount; + + // Fire token usage event for unknown request + tokenUsageEmitter.fire(createTokenUsage({ + sessionId, + requestId: unknownRequestId, + inputTokens: 1000, + outputTokens: 100 + })); + + // Verify branchTokens was NOT updated + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens).to.deep.equal({}); + + // Verify tokenTracker.resetSessionTokens was NOT called additionally + expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); + }); + }); +}); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts index e2e9aff5da2c1..8f1a808317468 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -15,17 +15,20 @@ // ***************************************************************************** import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; -import { Disposable, ILogger } from '@theia/core'; +import { ILogger, nls } from '@theia/core'; import { FrontendApplicationContribution } from '@theia/core/lib/browser'; -import { AgentService, TokenUsageServiceClient } from '@theia/ai-core'; +import { TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; import { - ChatAgent, ChatService, + ChatSession, + ErrorChatResponseContent, + ErrorChatResponseContentImpl, MutableChatModel, MutableChatRequestModel } from '../common'; -import { ChatSessionSummaryAgent } from '../common/chat-session-summary-agent'; +import { isSessionCreatedEvent, isSessionDeletedEvent } from '../common/chat-service'; import { + CHAT_TOKEN_THRESHOLD, ChatSessionTokenTracker, SessionTokenThresholdEvent } from './chat-session-token-tracker'; @@ -55,6 +58,7 @@ export interface ChatSessionSummarizationService { * @returns Promise that resolves with the summary text on success, or `undefined` on failure */ triggerSummarization(sessionId: string, skipReorder: boolean): Promise; + } @injectable() @@ -65,9 +69,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza @inject(ChatService) protected readonly chatService: ChatService; - @inject(AgentService) - protected readonly agentService: AgentService; - @inject(ILogger) protected readonly logger: ILogger; @@ -79,9 +80,41 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza */ protected summarizingSession = new Set(); + /** + * Tracks which branches have triggered summarization. + * Key format: `${sessionId}:${branchId}` + * Used for deduplication (prevents multiple triggers for the same branch). + */ + protected triggeredBranches: Set = new Set(); + @postConstruct() protected init(): void { - this.tokenTracker.onThresholdExceeded(event => this.handleThresholdExceeded(event)); + // Listen to token usage events and attribute to correct branch + this.tokenUsageClient.onTokenUsageUpdated(usage => this.handleTokenUsage(usage)); + + // Listen for new sessions and set up branch change listeners + this.chatService.onSessionEvent(event => { + if (isSessionCreatedEvent(event)) { + const session = this.chatService.getSession(event.sessionId); + if (session) { + this.setupBranchChangeListener(session); + } + // Restore branch tokens from persisted data + if (event.branchTokens) { + this.tokenTracker.restoreBranchTokens(event.sessionId, event.branchTokens); + } + // Emit initial token count for active branch + if (session) { + const activeBranchId = this.getActiveBranchId(session); + if (activeBranchId) { + const tokens = this.tokenTracker.getBranchTokens(event.sessionId, activeBranchId); + this.tokenTracker.resetSessionTokens(event.sessionId, tokens); + } + } + } else if (isSessionDeletedEvent(event)) { + this.cleanupSession(event.sessionId); + } + }); } /** @@ -89,7 +122,72 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza * Required by FrontendApplicationContribution to ensure this service is instantiated. */ onStart(): void { - // Service initialization is handled in @postConstruct + // Set up branch change listeners for existing sessions + for (const session of this.chatService.getSessions()) { + this.setupBranchChangeListener(session); + } + } + + /** + * Get the active branch ID for a session. + */ + protected getActiveBranchId(session: ChatSession): string | undefined { + return session.model.getBranches().at(-1)?.id; + } + + /** + * Handle token usage events and attribute to correct branch. + */ + protected handleTokenUsage(usage: TokenUsage): void { + if (!usage.sessionId) { + return; + } + + const session = this.chatService.getSession(usage.sessionId); + if (!session) { + return; + } + + const model = session.model as MutableChatModel; + const branch = model.getBranch(usage.requestId); + if (!branch) { + this.logger.debug('Token event for unknown request', { sessionId: usage.sessionId, requestId: usage.requestId }); + return; + } + + // Skip summary requests - the per-summarization listener handles these + if (model.getRequest(usage.requestId)?.request.kind === 'summary') { + return; + } + + const totalInputTokens = usage.inputTokens + (usage.cachedInputTokens ?? 0) + (usage.readCachedInputTokens ?? 0); + this.tokenTracker.setBranchTokens(usage.sessionId, branch.id, totalInputTokens); + + const activeBranchId = this.getActiveBranchId(session); + + if (branch.id === activeBranchId) { + this.tokenTracker.resetSessionTokens(usage.sessionId, totalInputTokens); + // Check threshold for active branch only + const branchKey = `${usage.sessionId}:${branch.id}`; + if (totalInputTokens >= CHAT_TOKEN_THRESHOLD && !this.triggeredBranches.has(branchKey)) { + this.triggeredBranches.add(branchKey); + this.handleThresholdExceeded({ sessionId: usage.sessionId, inputTokens: totalInputTokens }); + } + } + } + + /** + * Set up a listener for branch changes in a chat session. + * When a branch change occurs (e.g., user edits an older message), reset token tracking. + */ + protected setupBranchChangeListener(session: ChatSession): void { + session.model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + this.logger.info(`Branch changed in session ${session.id}, switching to branch ${event.branch.id}`); + const storedTokens = this.tokenTracker.getBranchTokens(session.id, event.branch.id); + this.tokenTracker.resetSessionTokens(session.id, storedTokens); + } + }); } async triggerSummarization(sessionId: string, skipReorder: boolean): Promise { @@ -135,73 +233,120 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza this.summarizingSession.add(sessionId); try { - const position = skipReorder ? 'end' : 'beforeLast'; + // Always use 'end' position - reordering breaks the hierarchy structure + // because the summary is added as continuation of the trigger request + const position = 'end'; + // eslint-disable-next-line max-len + const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue.'; const summaryText = await model.insertSummary( - async summaryRequest => { - // Find and invoke the summary agent - const agent = this.agentService.getAgents().find( - (candidate): candidate is ChatAgent => - 'invoke' in candidate && - typeof candidate.invoke === 'function' && - candidate.id === ChatSessionSummaryAgent.ID - ); - - if (!agent) { - this.logger.error('ChatSessionSummaryAgent not found'); + async () => { + const invocation = await this.chatService.sendRequest(sessionId, { + text: summaryPrompt, + kind: 'summary' + }); + if (!invocation) { return undefined; } - // Set up listener to capture token usage - let capturedInputTokens: number | undefined; - const tokenUsageListener: Disposable = this.tokenUsageClient.onTokenUsageUpdated(usage => { - if (usage.requestId === summaryRequest.id) { - capturedInputTokens = usage.inputTokens; + const request = await invocation.requestCompleted; + + // Set up token listener to capture output tokens + let capturedOutputTokens: number | undefined; + const tokenListener = this.tokenUsageClient.onTokenUsageUpdated(usage => { + if (usage.sessionId === sessionId && usage.requestId === request.id) { + capturedOutputTokens = usage.outputTokens; } }); try { - await agent.invoke(summaryRequest); - } finally { - tokenUsageListener.dispose(); - } + const response = await invocation.responseCompleted; - // Store captured tokens for later use - if (capturedInputTokens !== undefined) { - summaryRequest.addData('capturedInputTokens', capturedInputTokens); - } + // Validate response + const summaryResponseText = response.response.asDisplayString()?.trim(); + if (response.isError || !summaryResponseText) { + return undefined; + } - return summaryRequest.response.response.asDisplayString(); + // Store captured output tokens on request for later retrieval + if (capturedOutputTokens !== undefined) { + (request as MutableChatRequestModel).addData('capturedOutputTokens', capturedOutputTokens); + } + + return { + requestId: request.id, + summaryText: summaryResponseText + }; + } finally { + tokenListener.dispose(); + } }, position ); if (!summaryText) { this.logger.warn(`Summarization failed for session ${sessionId}`); + this.notifyUserOfFailure(model); return undefined; } this.logger.info(`Added summary node to session ${sessionId}`); - // Reset token count using captured tokens - const lastSummaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); - const capturedTokens = lastSummaryRequest?.getDataByKey('capturedInputTokens'); - if (capturedTokens !== undefined) { - this.tokenTracker.resetSessionTokens(sessionId, capturedTokens); - this.tokenTracker.resetThresholdTrigger(sessionId); - this.logger.info(`Reset token count for session ${sessionId} to ${capturedTokens} tokens`); + // Find the summary request to get captured output tokens + const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); + const outputTokens = summaryRequest?.getDataByKey('capturedOutputTokens') ?? 0; + + // Reset token count to the summary's output tokens (the new context size) + this.tokenTracker.resetSessionTokens(sessionId, outputTokens); + this.logger.info(`Reset token count for session ${sessionId} to ${outputTokens} after summarization`); + + // Update branch tokens and allow future re-triggering + const session = this.chatService.getSession(sessionId); + if (session) { + const activeBranchId = this.getActiveBranchId(session); + if (activeBranchId) { + const branchKey = `${sessionId}:${activeBranchId}`; + this.tokenTracker.setBranchTokens(sessionId, activeBranchId, outputTokens); + this.triggeredBranches.delete(branchKey); + } } return summaryText; } catch (error) { this.logger.error(`Failed to summarize session ${sessionId}:`, error); + this.notifyUserOfFailure(model); return undefined; } finally { this.summarizingSession.delete(sessionId); } } + /** + * Notify the user that summarization failed by adding an error message to the chat. + */ + protected notifyUserOfFailure(model: MutableChatModel): void { + const requests = model.getRequests(); + const currentRequest = requests.at(-1); + if (!currentRequest) { + return; + } + + // Avoid duplicate warnings + const lastContent = currentRequest.response.response.content.at(-1); + const alreadyWarned = ErrorChatResponseContent.is(lastContent) && + lastContent.error.message.includes('summarization'); + if (!alreadyWarned) { + const errorMessage = nls.localize( + 'theia/ai-chat/summarizationFailed', + 'Chat summarization failed. The conversation is approaching token limits and may fail soon. Consider starting a new chat session or reducing context to continue.' + ); + currentRequest.response.response.addContent( + new ErrorChatResponseContentImpl(new Error(errorMessage)) + ); + } + } + hasSummary(sessionId: string): boolean { const session = this.chatService.getSession(sessionId); if (!session) { @@ -212,4 +357,17 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return model.getRequests().some(r => (r as MutableChatRequestModel).isStale); } + /** + * Clean up token tracking data when a session is deleted. + */ + protected cleanupSession(sessionId: string): void { + this.tokenTracker.clearSessionBranchTokens(sessionId); + const prefix = `${sessionId}:`; + for (const key of this.triggeredBranches.keys()) { + if (key.startsWith(prefix)) { + this.triggeredBranches.delete(key); + } + } + } + } diff --git a/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts b/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts deleted file mode 100644 index c366503e137e5..0000000000000 --- a/packages/ai-chat/src/browser/chat-session-token-restore-contribution.ts +++ /dev/null @@ -1,42 +0,0 @@ -// ***************************************************************************** -// Copyright (C) 2025 EclipseSource GmbH. -// -// This program and the accompanying materials are made available under the -// terms of the Eclipse Public License v. 2.0 which is available at -// http://www.eclipse.org/legal/epl-2.0. -// -// This Source Code may also be made available under the following Secondary -// Licenses when the conditions for such availability set forth in the Eclipse -// Public License v. 2.0 are satisfied: GNU General Public License, version 2 -// with the GNU Classpath Exception which is available at -// https://www.gnu.org/software/classpath/license.html. -// -// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 -// ***************************************************************************** - -import { inject, injectable } from '@theia/core/shared/inversify'; -import { FrontendApplicationContribution } from '@theia/core/lib/browser'; -import { ChatService, isSessionCreatedEvent } from '../common/chat-service'; -import { ChatSessionTokenTracker } from '../common/chat-session-token-tracker'; - -/** - * Contribution that wires ChatService session events to the token tracker. - * This breaks the circular dependency between ChatService and ChatSessionTokenTracker - * by deferring the wiring until after both services are fully constructed. - */ -@injectable() -export class ChatSessionTokenRestoreContribution implements FrontendApplicationContribution { - @inject(ChatService) - protected readonly chatService: ChatService; - - @inject(ChatSessionTokenTracker) - protected readonly tokenTracker: ChatSessionTokenTracker; - - onStart(): void { - this.chatService.onSessionEvent(event => { - if (isSessionCreatedEvent(event) && event.tokenCount !== undefined) { - this.tokenTracker.resetSessionTokens(event.sessionId, event.tokenCount); - } - }); - } -} diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts index 2adababb48c65..a08d4c7624251 100644 --- a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts @@ -15,193 +15,130 @@ // ***************************************************************************** import { expect } from 'chai'; -import * as sinon from 'sinon'; import { Container } from '@theia/core/shared/inversify'; -import { Emitter } from '@theia/core'; -import { TokenUsageServiceClient, TokenUsage } from '@theia/ai-core/lib/common'; -import { ChatSessionTokenTrackerImpl, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; -import { SessionTokenThresholdEvent, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; +import { ChatSessionTokenTrackerImpl } from './chat-session-token-tracker'; +import { SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; describe('ChatSessionTokenTrackerImpl', () => { let container: Container; let tracker: ChatSessionTokenTrackerImpl; - let mockTokenUsageEmitter: Emitter; - let mockTokenUsageClient: TokenUsageServiceClient; - - const createTokenUsage = (sessionId: string | undefined, inputTokens: number, requestId: string): TokenUsage => ({ - sessionId, - inputTokens, - outputTokens: 100, - requestId, - model: 'test-model', - timestamp: new Date() - }); beforeEach(() => { container = new Container(); - - // Create a mock TokenUsageServiceClient with controllable event emitter - mockTokenUsageEmitter = new Emitter(); - mockTokenUsageClient = { - notifyTokenUsage: sinon.stub(), - onTokenUsageUpdated: mockTokenUsageEmitter.event - }; - - // Bind dependencies - container.bind(TokenUsageServiceClient).toConstantValue(mockTokenUsageClient); container.bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); - tracker = container.get(ChatSessionTokenTrackerImpl); }); - afterEach(() => { - mockTokenUsageEmitter.dispose(); - sinon.restore(); - }); - describe('getSessionInputTokens', () => { - it('should return correct token count after usage is reported', () => { - const sessionId = 'session-1'; - const inputTokens = 5000; - - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); - - expect(tracker.getSessionInputTokens(sessionId)).to.equal(inputTokens); - }); - it('should return undefined for unknown session', () => { expect(tracker.getSessionInputTokens('unknown-session')).to.be.undefined; }); }); - describe('onThresholdExceeded', () => { - it('should fire when tokens exceed threshold', () => { + describe('resetSessionTokens', () => { + it('should update token count and fire onSessionTokensUpdated', () => { const sessionId = 'session-1'; - const inputTokens = CHAT_TOKEN_THRESHOLD + 1000; - const thresholdEvents: SessionTokenThresholdEvent[] = []; - - tracker.onThresholdExceeded(event => thresholdEvents.push(event)); - - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); + const updateEvents: SessionTokenUpdateEvent[] = []; - expect(thresholdEvents).to.have.length(1); - expect(thresholdEvents[0].sessionId).to.equal(sessionId); - expect(thresholdEvents[0].inputTokens).to.equal(inputTokens); - }); + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); - it('should not fire when tokens are below threshold', () => { - const sessionId = 'session-1'; - const inputTokens = CHAT_TOKEN_THRESHOLD - 1000; - const thresholdEvents: SessionTokenThresholdEvent[] = []; + // Set initial token count via resetSessionTokens + tracker.resetSessionTokens(sessionId, 50000); - tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(updateEvents).to.have.length(1); + expect(updateEvents[0].sessionId).to.equal(sessionId); + expect(updateEvents[0].inputTokens).to.equal(50000); - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, inputTokens, 'request-1')); + // Reset to new baseline (simulating post-summarization) + const newTokenCount = 10000; + tracker.resetSessionTokens(sessionId, newTokenCount); - expect(thresholdEvents).to.have.length(0); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(newTokenCount); + expect(updateEvents).to.have.length(2); + expect(updateEvents[1].sessionId).to.equal(sessionId); + expect(updateEvents[1].inputTokens).to.equal(newTokenCount); }); - it('should not fire twice for the same session without reset', () => { + it('should delete token count and emit undefined when called with undefined', () => { const sessionId = 'session-1'; - const thresholdEvents: SessionTokenThresholdEvent[] = []; + const updateEvents: SessionTokenUpdateEvent[] = []; - tracker.onThresholdExceeded(event => thresholdEvents.push(event)); + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Set initial token count via resetSessionTokens + tracker.resetSessionTokens(sessionId, 50000); - // First token usage exceeding threshold - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(updateEvents).to.have.length(1); - // Second token usage exceeding threshold (should not trigger again) - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 2000, 'request-2')); + // Reset to undefined (simulating switch to branch with no prior LLM requests) + tracker.resetSessionTokens(sessionId, undefined); - expect(thresholdEvents).to.have.length(1); + expect(tracker.getSessionInputTokens(sessionId)).to.be.undefined; + expect(updateEvents).to.have.length(2); + expect(updateEvents[1].sessionId).to.equal(sessionId); + expect(updateEvents[1].inputTokens).to.be.undefined; }); }); - describe('resetThresholdTrigger', () => { - it('should allow re-triggering after resetThresholdTrigger is called', () => { + describe('branch token methods', () => { + it('should set and get branch tokens', () => { const sessionId = 'session-1'; - const thresholdEvents: SessionTokenThresholdEvent[] = []; - - tracker.onThresholdExceeded(event => thresholdEvents.push(event)); - - // First token usage exceeding threshold - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); + const branchId = 'branch-1'; - expect(thresholdEvents).to.have.length(1); + expect(tracker.getBranchTokens(sessionId, branchId)).to.be.undefined; - // Reset the threshold trigger (simulating summarization completion) - tracker.resetThresholdTrigger(sessionId); + tracker.setBranchTokens(sessionId, branchId, 5000); - // Second token usage exceeding threshold should trigger again - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, CHAT_TOKEN_THRESHOLD + 3000, 'request-2')); - - expect(thresholdEvents).to.have.length(2); - expect(thresholdEvents[0].sessionId).to.equal(sessionId); - expect(thresholdEvents[1].sessionId).to.equal(sessionId); + expect(tracker.getBranchTokens(sessionId, branchId)).to.equal(5000); }); - it('should not affect other sessions', () => { - const sessionId1 = 'session-1'; - const sessionId2 = 'session-2'; - const thresholdEvents: SessionTokenThresholdEvent[] = []; - - tracker.onThresholdExceeded(event => thresholdEvents.push(event)); - - // Trigger threshold for session 1 - mockTokenUsageEmitter.fire(createTokenUsage(sessionId1, CHAT_TOKEN_THRESHOLD + 1000, 'request-1')); - - // Trigger threshold for session 2 - mockTokenUsageEmitter.fire(createTokenUsage(sessionId2, CHAT_TOKEN_THRESHOLD + 1000, 'request-2')); - - expect(thresholdEvents).to.have.length(2); + it('should get all branch tokens for a session', () => { + const sessionId = 'session-1'; - // Reset only session 1 - tracker.resetThresholdTrigger(sessionId1); + tracker.setBranchTokens(sessionId, 'branch-1', 1000); + tracker.setBranchTokens(sessionId, 'branch-2', 2000); + tracker.setBranchTokens('other-session', 'branch-3', 3000); - // Session 1 should be able to trigger again - mockTokenUsageEmitter.fire(createTokenUsage(sessionId1, CHAT_TOKEN_THRESHOLD + 2000, 'request-3')); + const result = tracker.getBranchTokensForSession(sessionId); - // Session 2 should not trigger again (not reset) - mockTokenUsageEmitter.fire(createTokenUsage(sessionId2, CHAT_TOKEN_THRESHOLD + 2000, 'request-4')); + expect(result).to.deep.equal({ + 'branch-1': 1000, + 'branch-2': 2000 + }); + }); - expect(thresholdEvents).to.have.length(3); - expect(thresholdEvents[2].sessionId).to.equal(sessionId1); + it('should return empty object when no branch tokens exist for session', () => { + const result = tracker.getBranchTokensForSession('unknown-session'); + expect(result).to.deep.equal({}); }); - }); - describe('resetSessionTokens', () => { - it('should update token count and fire onSessionTokensUpdated', () => { + it('should restore branch tokens from persisted data', () => { const sessionId = 'session-1'; - const updateEvents: SessionTokenUpdateEvent[] = []; - - tracker.onSessionTokensUpdated(event => updateEvents.push(event)); - - // Set initial token count - mockTokenUsageEmitter.fire(createTokenUsage(sessionId, 50000, 'request-1')); - - expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); - expect(updateEvents).to.have.length(1); + const branchTokens = { + 'branch-1': 1000, + 'branch-2': 2000 + }; - // Reset to new baseline (simulating post-summarization) - const newTokenCount = 10000; - tracker.resetSessionTokens(sessionId, newTokenCount); + tracker.restoreBranchTokens(sessionId, branchTokens); - expect(tracker.getSessionInputTokens(sessionId)).to.equal(newTokenCount); - expect(updateEvents).to.have.length(2); - expect(updateEvents[1].sessionId).to.equal(sessionId); - expect(updateEvents[1].inputTokens).to.equal(newTokenCount); + expect(tracker.getBranchTokens(sessionId, 'branch-1')).to.equal(1000); + expect(tracker.getBranchTokens(sessionId, 'branch-2')).to.equal(2000); }); - }); - describe('token usage handling', () => { - it('should ignore token usage without sessionId', () => { - const updateEvents: SessionTokenUpdateEvent[] = []; + it('should clear all branch tokens for a session', () => { + const sessionId = 'session-1'; - tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + tracker.setBranchTokens(sessionId, 'branch-1', 1000); + tracker.setBranchTokens(sessionId, 'branch-2', 2000); + tracker.setBranchTokens('other-session', 'branch-3', 3000); - mockTokenUsageEmitter.fire(createTokenUsage(undefined, 5000, 'request-1')); + tracker.clearSessionBranchTokens(sessionId); - expect(updateEvents).to.have.length(0); + expect(tracker.getBranchTokens(sessionId, 'branch-1')).to.be.undefined; + expect(tracker.getBranchTokens(sessionId, 'branch-2')).to.be.undefined; + expect(tracker.getBranchTokens('other-session', 'branch-3')).to.equal(3000); }); }); }); diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.ts index 1bc86679bebaf..018b2d987cedd 100644 --- a/packages/ai-chat/src/browser/chat-session-token-tracker.ts +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.ts @@ -14,13 +14,20 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; -import { TokenUsageServiceClient, TokenUsage } from '@theia/ai-core/lib/common'; +import { injectable } from '@theia/core/shared/inversify'; import { Emitter } from '@theia/core'; -import { ChatSessionTokenTracker, SessionTokenThresholdEvent, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; +import { ChatSessionTokenTracker, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; // Re-export from common for backwards compatibility -export { ChatSessionTokenTracker, SessionTokenUpdateEvent, SessionTokenThresholdEvent } from '../common/chat-session-token-tracker'; +export { ChatSessionTokenTracker, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +/** + * Event fired when a session's token usage crosses the threshold. + */ +export interface SessionTokenThresholdEvent { + sessionId: string; + inputTokens: number; +} /** * Hardcoded token budget and threshold for chat sessions. @@ -31,12 +38,6 @@ export const CHAT_TOKEN_THRESHOLD = CHAT_TOKEN_BUDGET * CHAT_TOKEN_THRESHOLD_PER @injectable() export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { - @inject(TokenUsageServiceClient) - protected readonly tokenUsageClient: TokenUsageServiceClient; - - protected readonly onThresholdExceededEmitter = new Emitter(); - readonly onThresholdExceeded = this.onThresholdExceededEmitter.event; - protected readonly onSessionTokensUpdatedEmitter = new Emitter(); readonly onSessionTokensUpdated = this.onSessionTokensUpdatedEmitter.event; @@ -47,59 +48,64 @@ export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { protected sessionTokens = new Map(); /** - * Set of sessionIds that have already triggered the threshold event. - * Prevents multiple triggers for the same session. + * Map of branch tokens. Key format: `${sessionId}:${branchId}` */ - protected triggeredSessions = new Set(); + protected branchTokens = new Map(); - @postConstruct() - protected init(): void { - this.tokenUsageClient.onTokenUsageUpdated(usage => this.handleTokenUsage(usage)); + getSessionInputTokens(sessionId: string): number | undefined { + return this.sessionTokens.get(sessionId); } - protected handleTokenUsage(usage: TokenUsage): void { - const { sessionId, inputTokens } = usage; - - if (!sessionId) { - return; // Can't track without sessionId + /** + * Reset the session's token count to a new baseline. + * Called after summarization to reflect the reduced token usage. + * The new count should reflect only the summary + any non-stale messages. + * + * @param sessionId - The session ID to reset + * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined }`. + */ + resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void { + if (newTokenCount === undefined) { + this.sessionTokens.delete(sessionId); + } else { + this.sessionTokens.set(sessionId, newTokenCount); } + this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount }); + } - // Update the session's token count - this.sessionTokens.set(sessionId, inputTokens); - - // Fire the token update event - this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens }); + setBranchTokens(sessionId: string, branchId: string, tokens: number): void { + this.branchTokens.set(`${sessionId}:${branchId}`, tokens); + } - // Check if threshold is exceeded and we haven't already triggered - if (inputTokens >= CHAT_TOKEN_THRESHOLD && !this.triggeredSessions.has(sessionId)) { - this.triggeredSessions.add(sessionId); - this.onThresholdExceededEmitter.fire({ - sessionId, - inputTokens - }); - } + getBranchTokens(sessionId: string, branchId: string): number | undefined { + return this.branchTokens.get(`${sessionId}:${branchId}`); } - getSessionInputTokens(sessionId: string): number | undefined { - return this.sessionTokens.get(sessionId); + getBranchTokensForSession(sessionId: string): { [branchId: string]: number } { + const result: { [branchId: string]: number } = {}; + const prefix = `${sessionId}:`; + for (const [key, value] of this.branchTokens.entries()) { + if (key.startsWith(prefix)) { + const branchId = key.substring(prefix.length); + result[branchId] = value; + } + } + return result; } - /** - * Reset the triggered state for a session. - * Called after summarization is complete to allow future triggers - * if the session continues to grow. - */ - resetThresholdTrigger(sessionId: string): void { - this.triggeredSessions.delete(sessionId); + restoreBranchTokens(sessionId: string, branchTokens: { [branchId: string]: number }): void { + for (const [branchId, tokens] of Object.entries(branchTokens)) { + this.branchTokens.set(`${sessionId}:${branchId}`, tokens); + } } - /** - * Reset the session's token count to a new baseline. - * Called after summarization to reflect the reduced token usage. - * The new count should reflect only the summary + any non-stale messages. - */ - resetSessionTokens(sessionId: string, newTokenCount: number): void { - this.sessionTokens.set(sessionId, newTokenCount); - this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount }); + clearSessionBranchTokens(sessionId: string): void { + const prefix = `${sessionId}:`; + for (const key of this.branchTokens.keys()) { + if (key.startsWith(prefix)) { + this.branchTokens.delete(key); + } + } } } diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 773512f90ad6d..ed1afaab9fec2 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -297,6 +297,7 @@ export abstract class AbstractChatAgent implements ChatAgent { } const messages: LanguageModelMessage[] = []; + const text = request.message.parts.map(part => part.promptText).join(''); if (text.length > 0) { messages.push({ diff --git a/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts new file mode 100644 index 0000000000000..59d150d9fb23a --- /dev/null +++ b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts @@ -0,0 +1,165 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import { ChatAgentLocation } from './chat-agents'; +import { MutableChatModel } from './chat-model'; +import { ParsedChatRequest } from './parsed-chat-request'; + +describe('ChatRequestHierarchyBranchImpl', () => { + + function createParsedRequest(text: string): ParsedChatRequest { + return { + request: { text }, + parts: [{ + kind: 'text', + text, + promptText: text, + range: { start: 0, endExclusive: text.length } + }], + toolRequests: new Map(), + variables: [] + }; + } + + describe('remove()', () => { + it('should not fire onDidChange when removing the last item from a branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + let changeEventFired = false; + model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + changeEventFired = true; + } + }); + + branch!.remove(request); + + expect(changeEventFired).to.be.false; + expect(branch!.items.length).to.equal(0); + }); + + it('should fire onDidChange when removing a non-last item from a branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request1 = model.addRequest(createParsedRequest('First request')); + request1.response.complete(); + + // Add second request as an alternative in the same branch + const branch = model.getBranch(request1.id); + expect(branch).to.not.be.undefined; + const request2 = model.addRequest({ + ...createParsedRequest('Second request'), + request: { + text: 'Second request', + referencedRequestId: request1.id + } + }); + request2.response.complete(); + + let changeEventFired = false; + model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + changeEventFired = true; + } + }); + + branch!.remove(request1); + + expect(changeEventFired).to.be.true; + expect(branch!.items.length).to.equal(1); + }); + + it('should set activeBranchIndex to -1 when branch becomes empty', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.activeBranchIndex).to.equal(0); + + branch!.remove(request); + + expect(branch!.activeBranchIndex).to.equal(-1); + expect(branch!.items.length).to.equal(0); + }); + + it('should correctly adjust activeBranchIndex when removing item before active', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request1 = model.addRequest(createParsedRequest('First request')); + request1.response.complete(); + + // Add second request as an alternative in the same branch + const branch = model.getBranch(request1.id); + expect(branch).to.not.be.undefined; + + const request2 = model.addRequest({ + ...createParsedRequest('Second request'), + request: { + text: 'Second request', + referencedRequestId: request1.id + } + }); + request2.response.complete(); + + // After adding request2, it should be active (index 1) + expect(branch!.activeBranchIndex).to.equal(1); + + branch!.remove(request1); + + // After removing request1 (index 0), active index should be adjusted to 0 + expect(branch!.activeBranchIndex).to.equal(0); + expect(branch!.items.length).to.equal(1); + expect(branch!.get().id).to.equal(request2.id); + }); + }); + + describe('get()', () => { + it('should throw meaningful error when called on empty branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + // Remove the request to make branch empty + branch!.remove(request); + expect(branch!.items.length).to.equal(0); + + // get() should throw meaningful error instead of crashing + expect(() => branch!.get()).to.throw('Cannot get request from empty branch'); + }); + + it('should return request when branch has items', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Test request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.items.length).to.equal(1); + + // get() should return the request + expect(branch!.get().id).to.equal(request.id); + }); + }); +}); diff --git a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts index c4fdaea379063..d0fb28bedb9ff 100644 --- a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts +++ b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts @@ -16,14 +16,14 @@ import { expect } from 'chai'; import { ChatAgentLocation } from './chat-agents'; -import { ChatResponseContent, MutableChatModel, MutableChatRequestModel, SummaryChatResponseContent, TextChatResponseContentImpl } from './chat-model'; +import { MutableChatModel, SummaryChatResponseContent, SummaryChatResponseContentImpl } from './chat-model'; import { ParsedChatRequest } from './parsed-chat-request'; describe('MutableChatModel.insertSummary()', () => { - function createParsedRequest(text: string): ParsedChatRequest { + function createParsedRequest(text: string, kind?: 'user' | 'summary'): ParsedChatRequest { return { - request: { text }, + request: { text, kind }, parts: [{ kind: 'text', text, @@ -44,13 +44,32 @@ describe('MutableChatModel.insertSummary()', () => { return model; } + /** + * Helper to create a summary callback that simulates ChatService.sendRequest(). + * It creates the summary request directly on the model (as sendRequest would do internally) + * and returns the expected result structure. + */ + function createSummaryCallback(model: MutableChatModel, summaryText: string): () => Promise<{ requestId: string; summaryText: string } | undefined> { + return async () => { + // Simulate what ChatService.sendRequest() would do: create a request on the model + const summaryRequest = model.addRequest(createParsedRequest(summaryText, 'summary')); + // Add the summary content to the response + summaryRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryText)); + summaryRequest.response.complete(); + return { + requestId: summaryRequest.id, + summaryText + }; + }; + } + describe('basic functionality', () => { it('should return undefined when model has less than 2 requests', async () => { const model = new MutableChatModel(ChatAgentLocation.Panel); model.addRequest(createParsedRequest('Single request')); const result = await model.insertSummary( - async () => 'Summary text', + async () => ({ requestId: 'test-id', summaryText: 'Summary text' }), 'end' ); @@ -61,7 +80,7 @@ describe('MutableChatModel.insertSummary()', () => { const model = new MutableChatModel(ChatAgentLocation.Panel); const result = await model.insertSummary( - async () => 'Summary text', + async () => ({ requestId: 'test-id', summaryText: 'Summary text' }), 'end' ); @@ -72,7 +91,7 @@ describe('MutableChatModel.insertSummary()', () => { const model = createModelWithRequests(3); const result = await model.insertSummary( - async () => 'This is a summary', + createSummaryCallback(model, 'This is a summary'), 'end' ); @@ -85,7 +104,7 @@ describe('MutableChatModel.insertSummary()', () => { const model = createModelWithRequests(3); await model.insertSummary( - async () => 'Summary text', + createSummaryCallback(model, 'Summary text'), 'end' ); @@ -99,7 +118,7 @@ describe('MutableChatModel.insertSummary()', () => { const model = createModelWithRequests(3); await model.insertSummary( - async () => 'Summary text', + createSummaryCallback(model, 'Summary text'), 'end' ); @@ -116,7 +135,7 @@ describe('MutableChatModel.insertSummary()', () => { const model = createModelWithRequests(2); await model.insertSummary( - async () => 'The conversation summary', + createSummaryCallback(model, 'The conversation summary'), 'end' ); @@ -130,62 +149,8 @@ describe('MutableChatModel.insertSummary()', () => { }); }); - describe('position: beforeLast', () => { - it('should insert summary before the last request', async () => { - const model = createModelWithRequests(3); - const lastRequestId = model.getRequests()[2].id; - - await model.insertSummary( - async () => 'Summary text', - 'beforeLast' - ); - - const requests = model.getRequests(); - // Should have 4 requests: 3 original + 1 summary - expect(requests).to.have.lengthOf(4); - // Summary should be at index 2, original last request at index 3 - expect(requests[2].request.kind).to.equal('summary'); - expect(requests[3].id).to.equal(lastRequestId); - }); - - it('should preserve the trigger request identity (same object)', async () => { - const model = createModelWithRequests(3); - const originalLastRequest = model.getRequests()[2]; - const originalId = originalLastRequest.id; - - await model.insertSummary( - async () => 'Summary text', - 'beforeLast' - ); - - const readdedRequest = model.getRequests()[3]; - // Should be the exact same object - expect(readdedRequest.id).to.equal(originalId); - }); - - it('should mark all requests except trigger as stale', async () => { - const model = createModelWithRequests(3); - const triggerRequestId = model.getRequests()[2].id; - - await model.insertSummary( - async () => 'Summary text', - 'beforeLast' - ); - - const requests = model.getRequests(); - // Requests 1-2 (indices 0-1) should be stale - expect(requests[0].isStale).to.be.true; - expect(requests[1].isStale).to.be.true; - // Summary request (index 2) should not be stale - expect(requests[2].isStale).to.be.false; - // Trigger request (index 3) should not be stale - expect(requests[3].isStale).to.be.false; - expect(requests[3].id).to.equal(triggerRequestId); - }); - }); - describe('callback failure handling', () => { - it('should rollback on callback returning undefined (end position)', async () => { + it('should return undefined on callback returning undefined (end position)', async () => { const model = createModelWithRequests(3); const originalRequestCount = model.getRequests().length; @@ -195,15 +160,15 @@ describe('MutableChatModel.insertSummary()', () => { ); expect(result).to.be.undefined; - // Model should be unchanged + // Model should be unchanged - callback didn't create any request expect(model.getRequests()).to.have.lengthOf(originalRequestCount); - // Stale flags should be restored + // Stale flags should remain unchanged model.getRequests().forEach(r => { expect(r.isStale).to.be.false; }); }); - it('should rollback on callback throwing error (end position)', async () => { + it('should return undefined on callback throwing error (end position)', async () => { const model = createModelWithRequests(3); const originalRequestCount = model.getRequests().length; @@ -213,85 +178,55 @@ describe('MutableChatModel.insertSummary()', () => { ); expect(result).to.be.undefined; - // Model should be unchanged + // Model should be unchanged - callback didn't create any request before throwing expect(model.getRequests()).to.have.lengthOf(originalRequestCount); - // Stale flags should be restored - model.getRequests().forEach(r => { - expect(r.isStale).to.be.false; - }); - }); - - it('should rollback on callback failure (beforeLast position)', async () => { - const model = createModelWithRequests(3); - const originalRequestIds = model.getRequests().map(r => r.id); - - const result = await model.insertSummary( - async () => undefined, - 'beforeLast' - ); - - expect(result).to.be.undefined; - // Should have same requests in same order - const currentRequestIds = model.getRequests().map(r => r.id); - expect(currentRequestIds).to.deep.equal(originalRequestIds); - // Stale flags should be restored + // Stale flags should remain unchanged model.getRequests().forEach(r => { expect(r.isStale).to.be.false; }); }); - it('should restore trigger request on failure (beforeLast position)', async () => { - const model = createModelWithRequests(3); - const originalLastRequestId = model.getRequests()[2].id; - - const result = await model.insertSummary( - async () => { throw new Error('Agent failed'); }, - 'beforeLast' - ); - - expect(result).to.be.undefined; - // Trigger request should be back in position - const requests = model.getRequests(); - expect(requests).to.have.lengthOf(3); - expect(requests[2].id).to.equal(originalLastRequestId); - }); }); - describe('callback receives correct summaryRequest', () => { - it('should pass a valid MutableChatRequestModel to callback', async () => { + describe('callback creates request via model', () => { + it('should find created request by requestId after callback returns', async () => { const model = createModelWithRequests(2); - let receivedRequest: MutableChatRequestModel | undefined; + let createdRequestId: string | undefined; await model.insertSummary( - async summaryRequest => { - receivedRequest = summaryRequest; - return 'Summary'; + async () => { + // Simulate ChatService.sendRequest() creating a request + const createdSummaryRequest = model.addRequest(createParsedRequest('Summary', 'summary')); + createdSummaryRequest.response.response.addContent(new SummaryChatResponseContentImpl('Summary')); + createdSummaryRequest.response.complete(); + createdRequestId = createdSummaryRequest.id; + return { + requestId: createdSummaryRequest.id, + summaryText: 'Summary' + }; }, 'end' ); - expect(receivedRequest).to.not.be.undefined; - expect(receivedRequest!.request.kind).to.equal('summary'); - expect(receivedRequest!.response).to.not.be.undefined; + // The summary request should be findable in the model + const summaryRequest = model.getRequests().find(r => r.id === createdRequestId); + expect(summaryRequest).to.not.be.undefined; + expect(summaryRequest!.request.kind).to.equal('summary'); }); - it('should allow callback to use summaryRequest for agent invocation', async () => { + it('should return undefined if requestId references non-existent request', async () => { const model = createModelWithRequests(2); - let responseModified = false; - await model.insertSummary( - async summaryRequest => { - // Simulate agent adding content to response - summaryRequest.response.response.addContent( - new TextChatResponseContentImpl('Agent response') as ChatResponseContent - ); - responseModified = true; - return summaryRequest.response.response.asDisplayString(); - }, + const result = await model.insertSummary( + async () => ({ + requestId: 'non-existent-id', + summaryText: 'Summary' + }), 'end' ); - expect(responseModified).to.be.true; + // Should return undefined because request wasn't found + expect(result).to.be.undefined; }); }); @@ -302,7 +237,7 @@ describe('MutableChatModel.insertSummary()', () => { model.getRequests()[0].isStale = true; await model.insertSummary( - async () => 'Summary', + createSummaryCallback(model, 'Summary'), 'end' ); diff --git a/packages/ai-chat/src/common/chat-model-serialization.ts b/packages/ai-chat/src/common/chat-model-serialization.ts index 86edac1d6685c..bbb899b374840 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.ts @@ -15,7 +15,6 @@ // ***************************************************************************** import { ChatAgentLocation } from './chat-agents'; -import { ChatRequestKind } from './chat-model'; export interface SerializableChangeSetElement { kind?: string; @@ -42,8 +41,8 @@ export interface SerializableChatRequestData { id: string; text: string; agentId?: string; - /** The type of request: 'user' or 'summary'. Defaults to 'user' if not specified. */ - kind?: ChatRequestKind; + /** The type of request. Defaults to 'user' if not specified (for backward compatibility). */ + kind?: 'user' | 'summary'; /** Indicates this request has been summarized and should be excluded from prompt construction */ isStale?: boolean; changeSet?: { @@ -132,6 +131,7 @@ export interface SerializedChatData { model: SerializedChatModel; saveDate: number; lastInputTokens?: number; + branchTokens?: { [branchId: string]: number }; } export interface SerializableChatsData { diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index a77fd5b58453d..7d7fc055cc4ad 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -943,17 +943,20 @@ export class MutableChatModel implements ChatModel, Disposable { /** * Insert a summary into the model. - * Handles request reordering, stale marking, and summary content creation. + * Handles stale marking for older requests. * - * @param summaryCallback - Callback that invokes the summary agent. - * Receives the summary request (already added to model). - * Should invoke the agent and return the summary text, or undefined on failure. - * @param position - 'end' appends summary at end, 'beforeLast' inserts before the last request + * Note: Only 'end' position is currently supported. The 'beforeLast' mode was removed + * because reordering breaks the hierarchy structure - the summary gets added as a + * continuation of the trigger request, and removing the trigger loses the branch connection. + * + * @param summaryCallback - Callback that creates the summary request via ChatService.sendRequest() + * and returns the requestId and summaryText, or undefined on failure. + * @param position - Currently only 'end' is supported (appends summary at end) * @returns The summary text on success, or undefined on failure */ async insertSummary( - summaryCallback: (summaryRequest: MutableChatRequestModel) => Promise, - position: 'end' | 'beforeLast' + summaryCallback: () => Promise<{ requestId: string; summaryText: string } | undefined>, + position: 'end' ): Promise { const allRequests = this.getRequests(); @@ -963,74 +966,37 @@ export class MutableChatModel implements ChatModel, Disposable { } // The request to preserve (most recent exchange, not summarized) - // Captured before any modifications - same for both position modes const requestToPreserve = allRequests[allRequests.length - 1]; - let triggerRequest: MutableChatRequestModel | undefined; - let triggerBranch: ChatHierarchyBranch | undefined; - - if (position === 'beforeLast') { - // Remove the last request temporarily - it will be re-added after the summary - triggerRequest = requestToPreserve; - triggerBranch = this.getBranch(triggerRequest.id); - if (triggerBranch) { - triggerBranch.remove(triggerRequest); - } - } - // Identify which requests will be marked stale after successful summarization - // (all except the preserved one) + // (all non-stale requests except the preserved one) const requestsToMarkStale = allRequests.filter(r => !r.isStale && r !== requestToPreserve); - // Create summary request - // Use the ChatSessionSummaryAgent.ID constant value directly to avoid circular dependency - const summaryRequest = this.addRequest({ - request: { - text: '', - kind: 'summary' - }, - parts: [], - toolRequests: new Map(), - variables: [] - }, 'chat-session-summary-agent'); - - // Call the callback to invoke the agent + // Call the callback to create the summary request and invoke the agent // NOTE: Stale marking happens AFTER the callback so the summary agent can see all messages - let summaryText: string | undefined; + let result: { requestId: string; summaryText: string } | undefined; try { - summaryText = await summaryCallback(summaryRequest); + result = await summaryCallback(); } catch (error) { - summaryText = undefined; + result = undefined; } - if (!summaryText) { - // Rollback: remove summary request, re-add trigger if needed - const summaryBranch = this.getBranch(summaryRequest.id); - if (summaryBranch) { - summaryBranch.remove(summaryRequest); - } - if (position === 'beforeLast' && triggerRequest) { - this._hierarchy.append(triggerRequest as MutableChatRequestModel); - this._onDidChangeEmitter.fire({ kind: 'addRequest', request: triggerRequest }); - } + if (!result) { return undefined; } - // Success: mark requests as stale AFTER successful summarization - // This ensures the summary agent could see all messages when building the prompt - for (const request of requestsToMarkStale) { - request.isStale = true; - } + const { requestId, summaryText } = result; - // Update summary response with SummaryChatResponseContent - summaryRequest.response.response.clearContent(); - const summaryContent = new SummaryChatResponseContentImpl(summaryText); - summaryRequest.response.response.addContent(summaryContent); + // Find the created summary request using findRequest (handles hierarchy search) + const summaryRequest = this._hierarchy.findRequest(requestId); + if (!summaryRequest) { + // Request not found - treat as failure + return undefined; + } - // Re-add trigger request if beforeLast - if (position === 'beforeLast' && triggerRequest) { - this._hierarchy.append(triggerRequest as MutableChatRequestModel); - this._onDidChangeEmitter.fire({ kind: 'addRequest', request: triggerRequest }); + // Mark older requests as stale + for (const request of requestsToMarkStale) { + request.isStale = true; } return summaryText; @@ -1448,6 +1414,9 @@ export class ChatRequestHierarchyBranchImpl i } get(): TRequest { + if (this.items.length === 0 || this._activeIndex < 0) { + throw new Error('Cannot get request from empty branch'); + } return this.items[this.activeBranchIndex].element; } @@ -1464,8 +1433,12 @@ export class ChatRequestHierarchyBranchImpl i const index = this.items.findIndex(version => version.element.id === requestId); if (index !== -1) { this.items.splice(index, 1); - if (this.activeBranchIndex >= index) { - this.activeBranchIndex--; + if (this.items.length === 0) { + // Branch is now empty - set directly without firing event + this._activeIndex = -1; + } else if (this._activeIndex >= index) { + // Branch still has items - use setter to fire change event + this.activeBranchIndex = Math.max(0, this._activeIndex - 1); } } } diff --git a/packages/ai-chat/src/common/chat-service.ts b/packages/ai-chat/src/common/chat-service.ts index c33862981ef0d..8d100c35e8c88 100644 --- a/packages/ai-chat/src/common/chat-service.ts +++ b/packages/ai-chat/src/common/chat-service.ts @@ -86,6 +86,7 @@ export interface SessionCreatedEvent { type: 'created'; sessionId: string; tokenCount?: number; + branchTokens?: { [branchId: string]: number }; } export function isSessionCreatedEvent(obj: unknown): obj is SessionCreatedEvent { @@ -330,6 +331,9 @@ export class ChatServiceImpl implements ChatService { } protected updateSessionMetadata(session: ChatSessionInternal, request: MutableChatRequestModel): void { + if (request.request.kind === 'summary') { + return; + } session.lastInteraction = new Date(); if (session.title) { return; @@ -520,7 +524,8 @@ export class ChatServiceImpl implements ChatService { this.onSessionEventEmitter.fire({ type: 'created', sessionId: session.id, - tokenCount: serialized.lastInputTokens + tokenCount: serialized.lastInputTokens, + branchTokens: serialized.branchTokens }); this.logger.debug('Session successfully restored and registered', { sessionId, title: session.title }); diff --git a/packages/ai-chat/src/common/chat-session-store.ts b/packages/ai-chat/src/common/chat-session-store.ts index 434c1a7106b69..ea84e5287f870 100644 --- a/packages/ai-chat/src/common/chat-session-store.ts +++ b/packages/ai-chat/src/common/chat-session-store.ts @@ -25,6 +25,7 @@ export interface ChatModelWithMetadata { title?: string; pinnedAgentId?: string; lastInputTokens?: number; + branchTokens?: { [branchId: string]: number }; } export interface ChatSessionStore { diff --git a/packages/ai-chat/src/common/chat-session-summary-agent.ts b/packages/ai-chat/src/common/chat-session-summary-agent.ts index 298a27aea0ccc..67d001ab75544 100644 --- a/packages/ai-chat/src/common/chat-session-summary-agent.ts +++ b/packages/ai-chat/src/common/chat-session-summary-agent.ts @@ -31,6 +31,7 @@ export class ChatSessionSummaryAgent extends AbstractStreamParsingChatAgent impl override description = nls.localize('theia/ai/chat/chatSessionSummaryAgent/description', 'Agent for generating chat session summaries.'); override prompts: PromptVariantSet[] = [CHAT_SESSION_SUMMARY_PROMPT]; protected readonly defaultLanguageModelPurpose = 'chat-session-summary'; + protected override systemPromptId: string = CHAT_SESSION_SUMMARY_PROMPT.id; languageModelRequirements: LanguageModelRequirement[] = [{ purpose: 'chat-session-summary', identifier: 'default/summarize', diff --git a/packages/ai-chat/src/common/chat-session-token-tracker.ts b/packages/ai-chat/src/common/chat-session-token-tracker.ts index 71d7864a4fb9b..94b09b27f6027 100644 --- a/packages/ai-chat/src/common/chat-session-token-tracker.ts +++ b/packages/ai-chat/src/common/chat-session-token-tracker.ts @@ -21,15 +21,12 @@ import { Event } from '@theia/core'; */ export interface SessionTokenUpdateEvent { sessionId: string; - inputTokens: number; -} - -/** - * Event fired when a session's token usage crosses the threshold. - */ -export interface SessionTokenThresholdEvent { - sessionId: string; - inputTokens: number; + /** + * The input token count for the active branch. + * - `number`: Known token count from the most recent LLM response + * - `undefined`: Unknown/not yet measured (branch has never had an LLM request; do NOT coerce to 0) + */ + inputTokens: number | undefined; } export const ChatSessionTokenTracker = Symbol('ChatSessionTokenTracker'); @@ -42,11 +39,6 @@ export const ChatSessionTokenTracker = Symbol('ChatSessionTokenTracker'); * threshold (90% of 200k), it emits an event for summarization. */ export interface ChatSessionTokenTracker { - /** - * Event fired when a session's token usage crosses the threshold. - */ - readonly onThresholdExceeded: Event; - /** * Event fired when a session's token count is updated. */ @@ -61,13 +53,47 @@ export interface ChatSessionTokenTracker { /** * Reset the session's token count to a new baseline. * Called after summarization to reflect the reduced token usage. + * + * @param sessionId - The session ID to reset + * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined }`. + */ + resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void; + + /** + * Store token count for a specific branch. + * @param sessionId - The session ID + * @param branchId - The branch ID + * @param tokens - The token count + */ + setBranchTokens(sessionId: string, branchId: string, tokens: number): void; + + /** + * Get token count for a specific branch. + * @param sessionId - The session ID + * @param branchId - The branch ID + * @returns The token count, or undefined if not tracked + */ + getBranchTokens(sessionId: string, branchId: string): number | undefined; + + /** + * Get all branch token counts for a session. + * @param sessionId - The session ID + * @returns Object with branchId keys and token count values, or empty object if no data + */ + getBranchTokensForSession(sessionId: string): { [branchId: string]: number }; + + /** + * Restore branch tokens from persisted data. + * @param sessionId - The session ID + * @param branchTokens - Object with branchId keys and token count values */ - resetSessionTokens(sessionId: string, newTokenCount: number): void; + restoreBranchTokens(sessionId: string, branchTokens: { [branchId: string]: number }): void; /** - * Reset the triggered state for a session. - * Called after summarization is complete to allow future triggers - * if the session continues to grow. + * Clear all branch token data for a session. + * Called when a session is deleted. + * @param sessionId - The session ID */ - resetThresholdTrigger(sessionId: string): void; + clearSessionBranchTokens(sessionId: string): void; } From 48e23e5a3b48bcc7af3e5d178ac1a39a23dcca6a Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Mon, 5 Jan 2026 03:19:07 +0100 Subject: [PATCH 3/5] final summary updates --- .prompts/project-info.prompttemplate | 14 + .../src/node/anthropic-language-model.ts | 28 +- .../chat-token-usage-indicator.spec.ts | 24 +- .../browser/chat-token-usage-indicator.tsx | 11 +- .../chat-tree-view/chat-view-tree-widget.tsx | 50 +-- .../ai-chat-ui/src/browser/style/index.css | 21 + .../src/browser/ai-chat-frontend-module.ts | 4 +- .../chat-language-model-service.spec.ts | 86 ++-- .../browser/chat-language-model-service.ts | 212 +++++---- .../browser/chat-session-store-impl.spec.ts | 7 + ...chat-session-summarization-service.spec.ts | 332 +++++--------- .../chat-session-summarization-service.ts | 422 +++++++++++++----- .../chat-session-token-tracker.spec.ts | 109 +++++ .../src/browser/chat-session-token-tracker.ts | 60 ++- packages/ai-chat/src/common/chat-agents.ts | 49 +- .../src/common/chat-model-hierarchy.spec.ts | 32 ++ .../src/common/chat-model-serialization.ts | 2 +- packages/ai-chat/src/common/chat-model.ts | 5 +- .../src/common/chat-request-parser.spec.ts | 13 +- packages/ai-chat/src/common/chat-service.ts | 12 +- .../src/common/chat-session-token-tracker.ts | 31 +- .../frontend-language-model-service.ts | 50 ++- .../src/node/ollama-language-model.ts | 10 +- .../src/node/openai-response-api-utils.ts | 47 +- 24 files changed, 1055 insertions(+), 576 deletions(-) diff --git a/.prompts/project-info.prompttemplate b/.prompts/project-info.prompttemplate index 1a31686d4c108..5af299f7ddbe5 100644 --- a/.prompts/project-info.prompttemplate +++ b/.prompts/project-info.prompttemplate @@ -59,6 +59,20 @@ Tests are located in the same directory as the components under test. If you want to compile something, run the linter or tests, prefer to execute them for changed packages first, as they will run faster. Only build the full project once you are done for a final validation. +### Building and Running the Demo App + +The main example applications are in `/examples/`: + +| Command (from root) | Purpose | +|---------------------|---------| +| `npm install` | Install dependencies (required first) | +| `npm run build:browser` | Build all packages + browser app | +| `npm run start:browser` | Start browser example at localhost:3000 | +| `npm run start:electron` | Start Electron desktop app | +| `npm run watch` | Watch mode for development | + +**Requirements:** Node.js ≥18.17.0, <21 + ### Styling Theia permits extensive color theming and makes extensive use of CSS variables. Styles are typically located either in an `index.css` file for an entire package or in a diff --git a/packages/ai-anthropic/src/node/anthropic-language-model.ts b/packages/ai-anthropic/src/node/anthropic-language-model.ts index 4687bf98edd14..2ad01beba420b 100644 --- a/packages/ai-anthropic/src/node/anthropic-language-model.ts +++ b/packages/ai-anthropic/src/node/anthropic-language-model.ts @@ -263,7 +263,6 @@ export class AnthropicModel implements LanguageModel { const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { - const toolCalls: ToolCallback[] = []; let toolCall: ToolCallback | undefined; const currentMessages: Message[] = []; @@ -313,6 +312,17 @@ export class AnthropicModel implements LanguageModel { } else if (event.type === 'message_start') { currentMessages.push(event.message); currentMessage = event.message; + // Report input tokens immediately + if (that.tokenUsageService && event.message.usage) { + that.tokenUsageService.recordTokenUsage(that.id, { + inputTokens: event.message.usage.input_tokens, + outputTokens: event.message.usage.output_tokens, + cachedInputTokens: event.message.usage.cache_creation_input_tokens ?? undefined, + readCachedInputTokens: event.message.usage.cache_read_input_tokens ?? undefined, + requestId: request.requestId, + sessionId: request.sessionId + }); + } } else if (event.type === 'message_stop') { if (currentMessage) { yield { @@ -321,25 +331,25 @@ export class AnthropicModel implements LanguageModel { cache_creation_input_tokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, cache_read_input_tokens: currentMessage.usage.cache_read_input_tokens ?? undefined }; - // Record token usage if token usage service is available - if (that.tokenUsageService && currentMessage.usage) { - const tokenUsageParams: TokenUsageParams = { + // Report final token usage + if (that.tokenUsageService) { + that.tokenUsageService.recordTokenUsage(that.id, { inputTokens: currentMessage.usage.input_tokens, outputTokens: currentMessage.usage.output_tokens, cachedInputTokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, readCachedInputTokens: currentMessage.usage.cache_read_input_tokens ?? undefined, requestId: request.requestId, sessionId: request.sessionId - }; - await that.tokenUsageService.recordTokenUsage(that.id, tokenUsageParams); + }); } } - } } if (toolCalls.length > 0) { - // If singleRoundTrip is true, yield tool calls without executing them - // The caller is responsible for tool execution and continuation + // singleRoundTrip mode: Return tool calls to caller without executing them. + // This allows external tool loop management (e.g., for budget-aware summarization). + // When enabled, we yield the tool_calls and return immediately; the caller + // handles tool execution and decides whether to continue the conversation. if (request.singleRoundTrip) { const pendingCalls = toolCalls.map(tc => ({ finished: true, diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts index fbfc7705130c0..9d01dc4b709e0 100644 --- a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts @@ -26,8 +26,7 @@ import { flushSync } from '@theia/core/shared/react-dom'; import { Emitter } from '@theia/core'; import { ChatSessionTokenTracker, - SessionTokenUpdateEvent, - CHAT_TOKEN_THRESHOLD + SessionTokenUpdateEvent } from '@theia/ai-chat/lib/browser'; import { ChatTokenUsageIndicator, ChatTokenUsageIndicatorProps } from './chat-token-usage-indicator'; @@ -42,7 +41,10 @@ describe('ChatTokenUsageIndicator', () => { return { onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => tokens, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => tokens, resetSessionTokens: () => { }, + updateSessionTokens: () => { }, setBranchTokens: () => { }, getBranchTokens: () => undefined, getBranchTokensForSession: () => ({}), @@ -130,8 +132,7 @@ describe('ChatTokenUsageIndicator', () => { describe('color coding', () => { it('should have green class when usage is below 70%', () => { - // 70% of CHAT_TOKEN_THRESHOLD = 126000, so 100000 is below - expect(Math.round(CHAT_TOKEN_THRESHOLD * 0.7)).to.equal(126000); + // Below 70% of CHAT_TOKEN_THRESHOLD const mockTracker = createMockTokenTracker(100000); renderComponent({ sessionId: 'test-session', @@ -144,8 +145,7 @@ describe('ChatTokenUsageIndicator', () => { }); it('should have yellow class when usage is between 70% and 90%', () => { - // 70% of CHAT_TOKEN_THRESHOLD = 126000 - // 90% of CHAT_TOKEN_THRESHOLD = 162000 + // Between 70% and 90% of CHAT_TOKEN_THRESHOLD (180000 * 0.7 = 126000, 180000 * 0.9 = 162000) const mockTracker = createMockTokenTracker(150000); renderComponent({ sessionId: 'test-session', @@ -158,7 +158,7 @@ describe('ChatTokenUsageIndicator', () => { }); it('should have red class when usage is at or above 90%', () => { - // 90% of CHAT_TOKEN_THRESHOLD = 162000 + // At or above 90% of CHAT_TOKEN_THRESHOLD const mockTracker = createMockTokenTracker(170000); renderComponent({ sessionId: 'test-session', @@ -246,7 +246,10 @@ describe('ChatTokenUsageIndicator', () => { const mockTracker: ChatSessionTokenTracker = { onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => currentTokens, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => currentTokens, resetSessionTokens: () => { }, + updateSessionTokens: () => { }, setBranchTokens: () => { }, getBranchTokens: () => undefined, getBranchTokensForSession: () => ({}), @@ -267,7 +270,7 @@ describe('ChatTokenUsageIndicator', () => { // Fire update event within flushSync to ensure synchronous React update currentTokens = 100000; flushSync(() => { - updateEmitter.fire({ sessionId: 'test-session', inputTokens: 100000 }); + updateEmitter.fire({ sessionId: 'test-session', inputTokens: 100000, outputTokens: undefined }); }); textContent = container.textContent; @@ -280,7 +283,10 @@ describe('ChatTokenUsageIndicator', () => { const mockTracker: ChatSessionTokenTracker = { onSessionTokensUpdated: updateEmitter.event, getSessionInputTokens: () => 50000, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => 50000, resetSessionTokens: () => { }, + updateSessionTokens: () => { }, setBranchTokens: () => { }, getBranchTokens: () => undefined, getBranchTokensForSession: () => ({}), @@ -300,7 +306,7 @@ describe('ChatTokenUsageIndicator', () => { // Fire update event for different session within flushSync flushSync(() => { - updateEmitter.fire({ sessionId: 'other-session', inputTokens: 100000 }); + updateEmitter.fire({ sessionId: 'other-session', inputTokens: 100000, outputTokens: undefined }); }); textContent = container.textContent; diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx index fa565a2baa081..a8a17e826607b 100644 --- a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx @@ -21,6 +21,11 @@ import { CHAT_TOKEN_THRESHOLD } from '@theia/ai-chat/lib/browser'; +/** Percentage of threshold at which to show warning color (yellow) */ +const TOKEN_USAGE_WARNING_PERCENT = 70; +/** Percentage of threshold at which to show critical color (red) */ +const TOKEN_USAGE_CRITICAL_PERCENT = 90; + export interface ChatTokenUsageIndicatorProps { sessionId: string; tokenTracker: ChatSessionTokenTracker; @@ -54,10 +59,10 @@ const getUsageColorClass = (tokens: number | undefined, threshold: number): stri return 'token-usage-none'; } const percentage = (tokens / threshold) * 100; - if (percentage >= 90) { + if (percentage >= TOKEN_USAGE_CRITICAL_PERCENT) { return 'token-usage-red'; } - if (percentage >= 70) { + if (percentage >= TOKEN_USAGE_WARNING_PERCENT) { return 'token-usage-yellow'; } return 'token-usage-green'; @@ -108,7 +113,7 @@ export const ChatTokenUsageIndicator: React.FC = ( title={tooltipText} > - {currentFormatted} / {thresholdFormatted} tokens + {currentFormatted} / {budgetFormatted} tokens
); diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index 9e101a369756e..eff865af63d61 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -397,12 +397,7 @@ export class ChatViewTreeWidget extends TreeWidget { return this.request?.id ?? `empty-branch-${branch.id}`; }, get request(): ChatRequestModel { - // Guard against empty branches - can happen during insertSummary - try { - return branch.get(); - } catch { - return undefined as unknown as ChatRequestModel; - } + return branch.get(); }, branch, sessionId: this.chatModelId @@ -496,8 +491,8 @@ export class ChatViewTreeWidget extends TreeWidget { } const request = branch.get(); nodes.push(this.mapRequestToNode(branch)); - // Skip separate response node for summary requests - response is rendered within request node - if (request.request.kind !== 'summary') { + // Skip separate response node for summary/continuation requests - response is rendered within request node + if (request.request.kind !== 'summary' && request.request.kind !== 'continuation') { nodes.push(this.mapResponseToNode(request.response)); } }); @@ -523,9 +518,13 @@ export class ChatViewTreeWidget extends TreeWidget { return super.renderNode(node, props); } + // Check if this is a summary or continuation request - skip header for both + const isSummaryOrContinuation = isRequestNode(node) && + (node.request.request.kind === 'summary' || node.request.request.kind === 'continuation'); + return
this.handleContextMenu(node, e)}> - {this.renderAgent(node)} + {!isSummaryOrContinuation && this.renderAgent(node)} {this.renderDetail(node)}
; @@ -641,7 +640,7 @@ export class ChatViewTreeWidget extends TreeWidget { chatAgentService={this.chatAgentService} variableService={this.variableService} openerService={this.openerService} - renderResponseContent={(content: ChatResponseContent) => this.renderResponseContent(content)} + renderResponseContent={(content: ChatResponseContent, responseNode?: ResponseNode) => this.renderResponseContent(content, responseNode)} provideChatInputWidget={() => { const editableNode = node; if (isEditableRequestNode(editableNode)) { @@ -672,7 +671,7 @@ export class ChatViewTreeWidget extends TreeWidget { />; } - protected renderResponseContent(content: ChatResponseContent): React.ReactNode { + protected renderResponseContent(content: ChatResponseContent, node?: ResponseNode): React.ReactNode { const renderer = this.chatResponsePartRenderers.getContributions().reduce<[number, ChatResponsePartRenderer | undefined]>( (prev, current) => { const prio = current.canHandle(content); @@ -684,7 +683,7 @@ export class ChatViewTreeWidget extends TreeWidget { if (!renderer) { return undefined; } - return renderer.render(content, undefined as unknown as ResponseNode); + return renderer.render(content, node as ResponseNode); } protected renderChatResponse(node: ResponseNode): React.ReactNode { @@ -800,13 +799,14 @@ const ChatRequestRender = ( variableService: AIVariableService, openerService: OpenerService, provideChatInputWidget: () => ReactWidget | undefined, - renderResponseContent?: (content: ChatResponseContent) => React.ReactNode, + renderResponseContent?: (content: ChatResponseContent, node?: ResponseNode) => React.ReactNode, }) => { // Capture the request object once to avoid getter issues const request = node.request; const parts = request.message.parts; const isStale = request.isStale === true; const isSummary = request.request.kind === 'summary'; + const isContinuation = request.request.kind === 'continuation'; if (EditableChatRequestModel.isEditing(request)) { const widget = provideChatInputWidget(); @@ -847,17 +847,17 @@ const ChatRequestRender = ( return (
- {isSummary && ( -
- - {nls.localize('theia/ai-chat/summary', 'Conversation Summary')} -
- )} - {isSummary && renderResponseContent ? ( -
- {request.response.response.content.map((c, i) => ( -
{renderResponseContent(c)}
- ))} + {(isSummary || isContinuation) && renderResponseContent ? ( +
+ {request.response.response.content.map((c, i) => { + const syntheticResponseNode: ResponseNode = { + id: request.response.id, + parent: node.parent, + response: request.response, + sessionId: node.sessionId + }; + return
{renderResponseContent(c, syntheticResponseNode)}
; + })}
) : (

@@ -896,7 +896,7 @@ const ChatRequestRender = ( })}

)} - {!isSummary && renderFooter()} + {!isSummary && !isContinuation && renderFooter()}
); }; diff --git a/packages/ai-chat-ui/src/browser/style/index.css b/packages/ai-chat-ui/src/browser/style/index.css index d1d5ec7605fc9..37a365b4dfa9c 100644 --- a/packages/ai-chat-ui/src/browser/style/index.css +++ b/packages/ai-chat-ui/src/browser/style/index.css @@ -1331,6 +1331,27 @@ details[open].collapsible-arguments .collapsible-arguments-summary { font-weight: 600; color: var(--theia-foreground); user-select: none; + list-style: none; + position: relative; +} + +.theia-chat-summary summary::-webkit-details-marker { + display: none; +} + +.theia-chat-summary summary::before { + content: "\25BC"; + position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + transition: transform 0.2s ease; + color: var(--theia-descriptionForeground); + font-size: var(--theia-ui-font-size1); +} + +.theia-chat-summary details:not([open]) summary::before { + transform: translateY(-50%) rotate(-90deg); } .theia-chat-summary summary:hover { diff --git a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts index 96fcdbb1aa3f5..9005993b2fcf2 100644 --- a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts +++ b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts @@ -28,7 +28,8 @@ import { ToolCallChatResponseContentFactory, PinChatAgent, ChatServiceFactory, - ChatAgentServiceFactory + ChatAgentServiceFactory, + ChatSessionSummarizationServiceSymbol } from '../common'; import { ChatAgentsVariableContribution } from '../common/chat-agents-variable-contribution'; import { CustomChatAgent } from '../common/custom-chat-agent'; @@ -195,6 +196,7 @@ export default new ContainerModule((bind, unbind, isBound, rebind) => { bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); bind(ChatSessionSummarizationService).toService(ChatSessionSummarizationServiceImpl); + bind(ChatSessionSummarizationServiceSymbol).toService(ChatSessionSummarizationService); bind(FrontendApplicationContribution).toService(ChatSessionSummarizationServiceImpl); // Rebind LanguageModelService to use the chat-aware implementation with budget-aware tool loop diff --git a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts index 5c9a138510552..4e6a864a24bfa 100644 --- a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts @@ -56,12 +56,16 @@ describe('ChatLanguageModelServiceImpl', () => { } as unknown as sinon.SinonStubbedInstance; mockTokenTracker = { - getSessionInputTokens: sinon.stub() + getSessionInputTokens: sinon.stub(), + getSessionOutputTokens: sinon.stub(), + getSessionTotalTokens: sinon.stub(), + updateSessionTokens: sinon.stub() } as unknown as sinon.SinonStubbedInstance; mockSummarizationService = { - triggerSummarization: sinon.stub(), - hasSummary: sinon.stub() + hasSummary: sinon.stub(), + markPendingSplit: sinon.stub(), + checkAndHandleSummarization: sinon.stub().resolves(false) } as unknown as sinon.SinonStubbedInstance; mockLogger = { @@ -106,7 +110,7 @@ describe('ChatLanguageModelServiceImpl', () => { const response = await service.sendRequest(mockLanguageModel, request); expect(isLanguageModelStreamResponse(response)).to.be.true; - expect(mockSummarizationService.triggerSummarization.called).to.be.false; + expect(mockSummarizationService.markPendingSplit.called).to.be.false; }); it('should delegate to super when request has no tools', async () => { @@ -125,7 +129,7 @@ describe('ChatLanguageModelServiceImpl', () => { const response = await service.sendRequest(mockLanguageModel, request); expect(isLanguageModelStreamResponse(response)).to.be.true; - expect(mockSummarizationService.triggerSummarization.called).to.be.false; + expect(mockSummarizationService.markPendingSplit.called).to.be.false; }); it('should use budget-aware handling when preference is enabled and tools are present', async () => { @@ -160,11 +164,10 @@ describe('ChatLanguageModelServiceImpl', () => { }); describe('budget checking', () => { - it('should trigger summarization after tool execution when budget is exceeded mid-loop', async () => { + it('should call markPendingSplit after tool execution when budget is exceeded mid-loop', async () => { mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); - // Return over threshold - budget check happens before request and after tool execution + // Return over threshold - budget check happens after tool execution mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); - mockSummarizationService.triggerSummarization.resolves(undefined); const toolHandler = sinon.stub().resolves('tool result'); const request: UserRequest = { @@ -179,18 +182,13 @@ describe('ChatLanguageModelServiceImpl', () => { }] }; - // First call: model returns tool call without result - const firstStream = createMockStream([ + // Model returns tool call without result + const mockStream = createMockStream([ { content: 'Let me use a tool' }, { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } ]); - // Second call: model returns final response - const secondStream = createMockStream([{ content: 'Done!' }]); - - mockLanguageModel.request - .onFirstCall().resolves({ stream: firstStream }) - .onSecondCall().resolves({ stream: secondStream }); + mockLanguageModel.request.resolves({ stream: mockStream }); const response = await service.sendRequest(mockLanguageModel, request); @@ -200,18 +198,23 @@ describe('ChatLanguageModelServiceImpl', () => { // just consume } - // Verify summarization was triggered both before request and after tool execution - expect(mockSummarizationService.triggerSummarization.calledTwice).to.be.true; - // First call (before request): no skipReorder - expect(mockSummarizationService.triggerSummarization.firstCall.calledWith('session-1')).to.be.true; - // Second call (mid-turn, after tool execution): skipReorder=true - expect(mockSummarizationService.triggerSummarization.secondCall.calledWith('session-1', true)).to.be.true; + // Verify markPendingSplit was called after tool execution (mid-turn budget exceeded) + expect(mockSummarizationService.markPendingSplit.calledOnce).to.be.true; + const markPendingCall = mockSummarizationService.markPendingSplit.firstCall; + expect(markPendingCall.args[0]).to.equal('session-1'); + expect(markPendingCall.args[1]).to.equal('request-1'); + // Third arg should be pending tool calls array + expect(markPendingCall.args[2]).to.be.an('array'); + // Fourth arg should be tool results map + expect(markPendingCall.args[3]).to.be.instanceOf(Map); + + // Loop should exit after split + expect(mockLanguageModel.request.calledOnce).to.be.true; }); - it('should trigger summarization before request when budget is exceeded', async () => { + it('should not trigger markPendingSplit when no tool calls are made', async () => { mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); - mockSummarizationService.triggerSummarization.resolves(undefined); const request: UserRequest = { sessionId: 'session-1', @@ -232,9 +235,8 @@ describe('ChatLanguageModelServiceImpl', () => { // just consume } - // Summarization should be triggered before the request since budget is exceeded - expect(mockSummarizationService.triggerSummarization.calledOnce).to.be.true; - expect(mockSummarizationService.triggerSummarization.calledWith('session-1')).to.be.true; + // markPendingSplit should NOT be called when there are no pending tool calls + expect(mockSummarizationService.markPendingSplit.called).to.be.false; }); it('should preserve original messages (no message rebuilding)', async () => { @@ -364,6 +366,36 @@ describe('ChatLanguageModelServiceImpl', () => { }); describe('error handling', () => { + it('should throw error when model returns non-streaming response', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + // Model returns a non-streaming response (just text, not a stream) + mockLanguageModel.request.resolves({ text: 'Non-streaming response' }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consuming the stream should throw + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw before we get here + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.equal('Budget-aware tool loop requires streaming response. Model returned non-streaming response.'); + } + }); + it('should log context-too-long errors and re-throw', async () => { mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); mockTokenTracker.getSessionInputTokens.returns(100); diff --git a/packages/ai-chat/src/browser/chat-language-model-service.ts b/packages/ai-chat/src/browser/chat-language-model-service.ts index 3e7bceda531e6..50e631f9ba3ca 100644 --- a/packages/ai-chat/src/browser/chat-language-model-service.ts +++ b/packages/ai-chat/src/browser/chat-language-model-service.ts @@ -34,8 +34,7 @@ import { import { BUDGET_AWARE_TOOL_LOOP_PREF } from '../common/ai-chat-preferences'; import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; import { ChatSessionSummarizationService } from './chat-session-summarization-service'; -import { PREFERENCE_NAME_REQUEST_SETTINGS, RequestSetting } from '@theia/ai-core/lib/common/ai-core-preferences'; -import { mergeRequestSettings } from '@theia/ai-core/lib/browser/frontend-language-model-service'; +import { applyRequestSettings } from '@theia/ai-core/lib/browser/frontend-language-model-service'; /** * Chat-specific language model service that adds budget-aware tool loop handling. @@ -72,22 +71,7 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { languageModel: LanguageModel, request: UserRequest ): Promise { - // Apply request settings (matching FrontendLanguageModelServiceImpl behavior) - const requestSettings = this.preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); - const ids = languageModel.id.split('/'); - const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], request.agentId); - if (matchingSetting?.requestSettings) { - request.settings = { - ...matchingSetting.requestSettings, - ...request.settings - }; - } - if (matchingSetting?.clientSettings) { - request.clientSettings = { - ...matchingSetting.clientSettings, - ...request.clientSettings - }; - } + applyRequestSettings(request, languageModel.id, request.agentId, this.preferenceService); const budgetAwareEnabled = this.preferenceService.get(BUDGET_AWARE_TOOL_LOOP_PREF, false); @@ -107,12 +91,6 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { languageModel: LanguageModel, request: UserRequest ): Promise { - // Check if budget is exceeded BEFORE sending - if (request.sessionId && this.isBudgetExceeded(request.sessionId)) { - this.logger.info(`Budget exceeded before request for session ${request.sessionId}, triggering summarization...`); - await this.summarizationService.triggerSummarization(request.sessionId, false); - } - const modifiedRequest: UserRequest = { ...request, singleRoundTrip: true @@ -122,6 +100,7 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { /** * Execute the tool loop, handling tool calls and budget checks between iterations. + * This method coordinates the overall flow, delegating to helper methods for specific tasks. */ protected async executeToolLoop( languageModel: LanguageModel, @@ -133,8 +112,6 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { // State that persists across the async iterator let currentMessages = [...request.messages]; - let pendingToolCalls: ToolCall[] = []; - let modelHandledLoop = false; const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { @@ -142,81 +119,42 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { while (continueLoop) { continueLoop = false; - pendingToolCalls = []; - modelHandledLoop = false; - - // Create request with current messages - const currentRequest: UserRequest = { - ...request, - messages: currentMessages, - singleRoundTrip: true - }; - - let response: LanguageModelResponse; - try { - // Call the parent's sendRequest to get the response - response = await LanguageModelServiceImpl.prototype.sendRequest.call( - that, languageModel, currentRequest - ); - } catch (error) { - // Check if this is a "context too long" error - const errorMessage = error instanceof Error ? error.message : String(error); - if (errorMessage.toLowerCase().includes('context') || - errorMessage.toLowerCase().includes('token') || - errorMessage.toLowerCase().includes('too long') || - errorMessage.toLowerCase().includes('max_tokens')) { - that.logger.error( - 'Context too long error for session ' + sessionId + '. ' + - 'Cannot recover - summarization also requires an LLM call.', - error - ); - } - // Re-throw to let the chat agent handle and display the error - throw error; - } - if (!isLanguageModelStreamResponse(response)) { - // Non-streaming response - just return as-is - // This shouldn't happen with singleRoundTrip but handle gracefully - return; - } + // Get response from model + const response = await that.sendSingleRoundTripRequest( + languageModel, request, currentMessages, sessionId + ); - // Process the stream - for await (const part of response.stream) { - // Collect tool calls to check if model respected singleRoundTrip - if (isToolCallResponsePart(part)) { - for (const toolCall of part.tool_calls) { - // If any tool call has a result, the model handled the loop internally - if (toolCall.result !== undefined) { - modelHandledLoop = true; - } - // Collect finished tool calls without results (model respected singleRoundTrip) - if (toolCall.finished && toolCall.result === undefined && toolCall.id) { - pendingToolCalls.push(toolCall); - } - } - } - yield part; + // Process the stream and collect tool calls + const streamProcessor = that.processResponseStream(response.stream); + let streamResult: IteratorResult; + + // Yield all parts from the stream processor + while (!(streamResult = await streamProcessor.next()).done) { + yield streamResult.value; } + const { pendingToolCalls, modelHandledLoop } = streamResult.value; + // If model handled the loop internally, we're done if (modelHandledLoop) { return; } - // If there are pending tool calls, execute them and continue the loop + // If there are pending tool calls, execute them and check if we need to split if (pendingToolCalls.length > 0) { - // Execute tools - const toolResults = await that.executeTools(pendingToolCalls, tools); - - // Check budget after tool execution - if (that.isBudgetExceeded(sessionId)) { - that.logger.info(`Budget exceeded after tool execution for session ${sessionId}, triggering summarization...`); - // Pass skipReorder=true for mid-turn summarization to avoid disrupting the active request - await that.summarizationService.triggerSummarization(sessionId, true); + const { toolResults, shouldSplit } = await that.executeToolsAndCheckBudget( + pendingToolCalls, tools, sessionId + ); + + if (shouldSplit && sessionId) { + // Budget exceeded - mark pending split and exit cleanly + that.logger.info(`Splitting turn for session ${sessionId} due to budget exceeded`); + that.summarizationService.markPendingSplit(sessionId, request.requestId, pendingToolCalls, toolResults); + return; } - // Append tool messages to current messages + // Normal case - append tool messages and continue loop currentMessages = that.appendToolMessages( currentMessages, pendingToolCalls, @@ -241,10 +179,104 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { return { stream: asyncIterator }; } + /** + * Send a single round-trip request to the language model. + * Handles context-too-long errors and ensures streaming response. + */ + protected async sendSingleRoundTripRequest( + languageModel: LanguageModel, + request: UserRequest, + currentMessages: LanguageModelMessage[], + sessionId: string | undefined + ): Promise { + const currentRequest: UserRequest = { + ...request, + messages: currentMessages, + singleRoundTrip: true + }; + + let response: LanguageModelResponse; + try { + response = await LanguageModelServiceImpl.prototype.sendRequest.call( + this, languageModel, currentRequest + ); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + if (errorMessage.toLowerCase().includes('context') || + errorMessage.toLowerCase().includes('token') || + errorMessage.toLowerCase().includes('too long') || + errorMessage.toLowerCase().includes('max_tokens')) { + this.logger.error( + 'Context too long error for session ' + sessionId + '. ' + + 'Cannot recover - summarization also requires an LLM call.', + error + ); + } + throw error; + } + + if (!isLanguageModelStreamResponse(response)) { + throw new Error('Budget-aware tool loop requires streaming response. Model returned non-streaming response.'); + } + + return response; + } + + /** + * Process a response stream, collecting tool calls and yielding parts. + * @returns Object with pendingToolCalls and whether the model handled the loop internally + */ + protected async *processResponseStream( + stream: AsyncIterable + ): AsyncGenerator { + const pendingToolCalls: ToolCall[] = []; + let modelHandledLoop = false; + + for await (const part of stream) { + if (isToolCallResponsePart(part)) { + for (const toolCall of part.tool_calls) { + // If any tool call has a result, the model handled the loop internally + if (toolCall.result !== undefined) { + modelHandledLoop = true; + } + // Collect finished tool calls without results (model respected singleRoundTrip) + if (toolCall.finished && toolCall.result === undefined && toolCall.id) { + pendingToolCalls.push(toolCall); + } + } + } + yield part; + } + + return { pendingToolCalls, modelHandledLoop }; + } + + /** + * Execute pending tool calls and check if budget is exceeded. + * Returns a signal indicating if the turn should be split. + */ + protected async executeToolsAndCheckBudget( + pendingToolCalls: ToolCall[], + tools: ToolRequest[], + sessionId: string | undefined + ): Promise<{ toolResults: Map; shouldSplit: boolean }> { + const toolResults = await this.executeTools(pendingToolCalls, tools); + + const shouldSplit = sessionId !== undefined && this.isBudgetExceeded(sessionId); + if (shouldSplit) { + this.logger.info(`Budget exceeded after tool execution for session ${sessionId}, will trigger split...`); + } + + return { toolResults, shouldSplit }; + } + /** * Check if the token budget is exceeded for a session. */ - protected isBudgetExceeded(sessionId: string): boolean { + protected isBudgetExceeded(sessionId: string | undefined): boolean { + if (!sessionId) { + return false; + } const tokens = this.tokenTracker.getSessionInputTokens(sessionId); return tokens !== undefined && tokens >= CHAT_TOKEN_THRESHOLD; } diff --git a/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts b/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts index 61e862d980739..a57b905cf1ba4 100644 --- a/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts @@ -34,6 +34,7 @@ import { BinaryBuffer } from '@theia/core/lib/common/buffer'; import { ChatSessionIndex, ChatSessionMetadata } from '../common/chat-session-store'; import { PERSISTED_SESSION_LIMIT_PREF } from '../common/ai-chat-preferences'; import { ChatAgentLocation } from '../common/chat-agents'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; disableJSDOM(); @@ -113,6 +114,12 @@ describe('ChatSessionStoreImpl', () => { container.bind('ChatSessionStore').toConstantValue(mockLogger); container.bind(ILogger).toConstantValue(mockLogger).whenTargetNamed('ChatSessionStore'); + const mockTokenTracker = { + getSessionInputTokens: sandbox.stub().returns(undefined), + getBranchTokensForSession: sandbox.stub().returns(undefined) + } as unknown as ChatSessionTokenTracker; + container.bind(ChatSessionTokenTracker).toConstantValue(mockTokenTracker); + container.bind(ChatSessionStoreImpl).toSelf().inSingletonScope(); chatSessionStore = container.get(ChatSessionStoreImpl); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts index 334173d6b2939..c3a2806150d72 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts @@ -18,11 +18,11 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; import { Container } from '@theia/core/shared/inversify'; import { Emitter, ILogger } from '@theia/core'; -import { TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; +import { TokenUsage, TokenUsageServiceClient, ToolCall, ToolCallResult } from '@theia/ai-core'; import { ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; -import { ChatRequestInvocation, ChatService, SessionCreatedEvent, SessionDeletedEvent } from '../common/chat-service'; -import { ChatRequestModel, ChatResponseModel, ChatSession } from '../common'; +import { ChatService, SessionCreatedEvent, SessionDeletedEvent } from '../common/chat-service'; +import { ChatSession } from '../common'; import { ChatSessionStore } from '../common/chat-session-store'; describe('ChatSessionSummarizationServiceImpl', () => { @@ -97,6 +97,9 @@ describe('ChatSessionSummarizationServiceImpl', () => { tokenTracker = { resetSessionTokens: sinon.stub(), getSessionInputTokens: sinon.stub(), + getSessionOutputTokens: sinon.stub(), + getSessionTotalTokens: sinon.stub(), + updateSessionTokens: sinon.stub(), onSessionTokensUpdated: sinon.stub(), setBranchTokens: sinon.stub().callsFake((sessionId: string, branchId: string, tokens: number) => { branchTokensMap.set(`${sessionId}:${branchId}`, tokens); @@ -174,224 +177,77 @@ describe('ChatSessionSummarizationServiceImpl', () => { sessionRegistry.clear(); }); - // Helper to create a mock ChatRequestInvocation - function createMockInvocation(params: { - requestId: string; - isError: boolean; - displayString: string; - errorObject?: Error; - }): ChatRequestInvocation { - const mockRequest = { - id: params.requestId, - request: { kind: 'summary' as const }, - addData: sinon.stub() - } as unknown as ChatRequestModel; - - const mockResponse = { - isError: params.isError, - errorObject: params.errorObject, - response: { - asDisplayString: () => params.displayString - } - } as unknown as ChatResponseModel; - - return { - requestCompleted: Promise.resolve(mockRequest), - responseCreated: Promise.resolve(mockResponse), - responseCompleted: Promise.resolve(mockResponse) - }; - } - - describe('performSummarization error handling', () => { - it('should return undefined and log warning when sendRequest response has error', async () => { - const sessionId = 'session-with-error'; - const branchId = 'branch-A'; - - // Create a mock model with insertSummary that calls the callback - const modelChangeEmitter = new Emitter(); - const mockModel = { - getBranch: sinon.stub(), - getBranches: sinon.stub().returns([{ id: branchId }]), - getRequest: sinon.stub(), - getRequests: sinon.stub().returns([]), - onDidChange: modelChangeEmitter.event, - insertSummary: sinon.stub().callsFake( - async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { - const callbackResult = await callback(); - return callbackResult?.summaryText; - } - ) - }; - - const session = { - id: sessionId, - isActive: true, - model: mockModel - } as unknown as ChatSession; - - sessionRegistry.set(sessionId, session); - - // Mock sendRequest to return an invocation with error response - (chatService.sendRequest as sinon.SinonStub).resolves( - createMockInvocation({ - requestId: 'summary-request-id', - isError: true, - displayString: '', - errorObject: new Error('No language model configured') - }) - ); - - // Call triggerSummarization - const result = await service.triggerSummarization(sessionId, false); + describe('markPendingSplit', () => { + it('should store pending split data', () => { + const sessionId = 'session-1'; + const requestId = 'request-1'; + const pendingToolCalls: ToolCall[] = [ + { id: 'tool-1', function: { name: 'test_tool', arguments: '{}' }, finished: false } + ]; + const toolResults = new Map([['tool-1', 'result']]); - // Verify result is undefined (error response returns undefined from callback) - expect(result).to.be.undefined; + service.markPendingSplit(sessionId, requestId, pendingToolCalls, toolResults); - // Verify warning was logged for failed summarization - expect((logger.warn as sinon.SinonStub).called).to.be.true; + // Verify info was logged + expect((logger.info as sinon.SinonStub).calledWithMatch('Marking pending split')).to.be.true; }); + }); - it('should return undefined and log warning when sendRequest returns empty response', async () => { - const sessionId = 'session-with-empty'; - const branchId = 'branch-A'; - - // Create a mock model with insertSummary that calls the callback - const modelChangeEmitter = new Emitter(); - const mockModel = { - getBranch: sinon.stub(), - getBranches: sinon.stub().returns([{ id: branchId }]), - getRequest: sinon.stub(), - getRequests: sinon.stub().returns([]), - onDidChange: modelChangeEmitter.event, - insertSummary: sinon.stub().callsFake( - async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { - const callbackResult = await callback(); - return callbackResult?.summaryText; - } - ) + describe('checkAndHandleSummarization', () => { + it('should return false when request kind is summary', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'summary' }, + response: { isComplete: false } }; - const session = { - id: sessionId, - isActive: true, - model: mockModel - } as unknown as ChatSession; - - sessionRegistry.set(sessionId, session); - - // Mock sendRequest to return an invocation with empty response - (chatService.sendRequest as sinon.SinonStub).resolves( - createMockInvocation({ - requestId: 'summary-request-id', - isError: false, - displayString: ' ' - }) + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel ); - // Call triggerSummarization - const result = await service.triggerSummarization(sessionId, false); - - // Verify result is undefined (empty response returns undefined from callback) - expect(result).to.be.undefined; - - // Verify warning was logged - expect((logger.warn as sinon.SinonStub).called).to.be.true; + expect(result).to.be.false; }); - it('should return summary text when response is successful', async () => { - const sessionId = 'session-success'; - const branchId = 'branch-A'; - const summaryText = 'This is a valid summary of the conversation.'; - - // Create a mock model with insertSummary that calls the callback - const modelChangeEmitter = new Emitter(); - const mockModel = { - getBranch: sinon.stub(), - getBranches: sinon.stub().returns([{ id: branchId }]), - getRequest: sinon.stub(), - getRequests: sinon.stub().returns([{ request: { kind: 'summary' }, getDataByKey: sinon.stub() }]), - onDidChange: modelChangeEmitter.event, - insertSummary: sinon.stub().callsFake( - async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { - const callbackResult = await callback(); - return callbackResult?.summaryText; - } - ) + it('should return false when request kind is continuation', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'continuation' }, + response: { isComplete: false } }; - const session = { - id: sessionId, - isActive: true, - model: mockModel - } as unknown as ChatSession; - - sessionRegistry.set(sessionId, session); - - // Mock sendRequest to return a successful invocation - (chatService.sendRequest as sinon.SinonStub).resolves( - createMockInvocation({ - requestId: 'summary-request-id', - isError: false, - displayString: summaryText - }) + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel ); - // Call triggerSummarization - const result = await service.triggerSummarization(sessionId, false); - - // Verify result is the summary text - expect(result).to.equal(summaryText); + expect(result).to.be.false; }); - it('should reset token count to output tokens after successful summarization', async () => { - const sessionId = 'session-with-output-tokens'; - const branchId = 'branch-A'; - const summaryText = 'This is a valid summary.'; - const outputTokens = 1500; + it('should return false when tokens are below threshold', async () => { + const sessionId = 'session-1'; + tokenTracker.getSessionInputTokens.returns(100); // Below threshold - // Create a mock model with insertSummary that calls the callback - const modelChangeEmitter = new Emitter(); - const summaryRequestMock = { - request: { kind: 'summary' }, - getDataByKey: sinon.stub().withArgs('capturedOutputTokens').returns(outputTokens) - }; - const mockModel = { - getBranch: sinon.stub(), - getBranches: sinon.stub().returns([{ id: branchId }]), - getRequest: sinon.stub(), - getRequests: sinon.stub().returns([summaryRequestMock]), - onDidChange: modelChangeEmitter.event, - insertSummary: sinon.stub().callsFake( - async (callback: () => Promise<{ requestId: string; summaryText: string } | undefined>) => { - const callbackResult = await callback(); - return callbackResult?.summaryText; - } - ) + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'user' }, + response: { isComplete: false } }; - const session = { - id: sessionId, - isActive: true, - model: mockModel - } as unknown as ChatSession; - - sessionRegistry.set(sessionId, session); - - // Mock sendRequest to return a successful invocation - (chatService.sendRequest as sinon.SinonStub).resolves( - createMockInvocation({ - requestId: 'summary-request-id', - isError: false, - displayString: summaryText - }) + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel ); - // Call triggerSummarization - await service.triggerSummarization(sessionId, false); - - // Verify token tracker was reset to output tokens (not 0) - expect((tokenTracker.resetSessionTokens as sinon.SinonStub).calledWith(sessionId, outputTokens)).to.be.true; - expect((tokenTracker.setBranchTokens as sinon.SinonStub).calledWith(sessionId, branchId, outputTokens)).to.be.true; + expect(result).to.be.false; }); }); @@ -412,9 +268,9 @@ describe('ChatSessionSummarizationServiceImpl', () => { outputTokens: 100 })); - // Verify branchTokens map is updated + // Verify branchTokens map is updated with totalTokens (inputTokens + outputTokens) const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.equal(1000); + expect(branchTokens[branchId]).to.equal(1100); // 1000 input + 100 output }); it('should update branchTokens when token usage event is for active branch', () => { @@ -433,9 +289,9 @@ describe('ChatSessionSummarizationServiceImpl', () => { outputTokens: 200 })); - // Verify branchTokens was updated (which confirms the handler ran and processed active branch) + // Verify branchTokens was updated with totalTokens (inputTokens + outputTokens) const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[activeBranchId]).to.equal(5000); + expect(branchTokens[activeBranchId]).to.equal(5200); // 5000 input + 200 output }); it('should NOT trigger tracker reset for non-active branch but should store tokens', () => { @@ -464,9 +320,9 @@ describe('ChatSessionSummarizationServiceImpl', () => { // Verify tokenTracker.resetSessionTokens was NOT called additionally expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); - // But branchTokens should be updated + // But branchTokens should be updated with totalTokens (inputTokens + outputTokens) const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[nonActiveBranchId]).to.equal(3000); + expect(branchTokens[nonActiveBranchId]).to.equal(3150); // 3000 input + 150 output }); it('should restore stored tokens when branch changes', () => { @@ -538,7 +394,7 @@ describe('ChatSessionSummarizationServiceImpl', () => { expect(tokenTracker.resetSessionTokens.calledWith(sessionId, undefined)).to.be.true; }); - it('should only trigger threshold for active branch', async () => { + it('should reset session tokens for active branch with valid input tokens', async () => { const sessionId = 'session-6'; const activeBranchId = 'branch-active'; const nonActiveBranchId = 'branch-other'; @@ -551,11 +407,7 @@ describe('ChatSessionSummarizationServiceImpl', () => { sessionRegistry.set(sessionId, session); - // Spy on handleThresholdExceeded - const handleThresholdSpy = sinon.spy( - service as unknown as { handleThresholdExceeded: (event: { sessionId: string; inputTokens: number }) => Promise }, - 'handleThresholdExceeded' - ); + const resetCallCountBefore = tokenTracker.resetSessionTokens.callCount; // Fire token usage event exceeding threshold for NON-active branch tokenUsageEmitter.fire(createTokenUsage({ @@ -565,8 +417,8 @@ describe('ChatSessionSummarizationServiceImpl', () => { outputTokens: 100 })); - // handleThresholdExceeded should NOT be called for non-active branch - expect(handleThresholdSpy.called).to.be.false; + // resetSessionTokens should NOT be called for non-active branch + expect(tokenTracker.resetSessionTokens.callCount).to.equal(resetCallCountBefore); // Now fire for active branch tokenUsageEmitter.fire(createTokenUsage({ @@ -576,12 +428,8 @@ describe('ChatSessionSummarizationServiceImpl', () => { outputTokens: 100 })); - // handleThresholdExceeded SHOULD be called for active branch - expect(handleThresholdSpy.calledOnce).to.be.true; - expect(handleThresholdSpy.calledWith({ - sessionId, - inputTokens: CHAT_TOKEN_THRESHOLD + 10000 - })).to.be.true; + // resetSessionTokens SHOULD be called for active branch with totalTokens (inputTokens + outputTokens) + expect(tokenTracker.resetSessionTokens.calledWith(sessionId, CHAT_TOKEN_THRESHOLD + 10100)).to.be.true; // threshold + 10000 input + 100 output }); it('should remove all branch entries when session is deleted', () => { @@ -593,9 +441,10 @@ describe('ChatSessionSummarizationServiceImpl', () => { tokenTracker.setBranchTokens('other-session', 'branch-X', 5000); const triggeredBranchesSet = (service as unknown as { triggeredBranches: Set }).triggeredBranches; - triggeredBranchesSet.add(`${sessionId}:branch-A`); - triggeredBranchesSet.add(`${sessionId}:branch-B`); - triggeredBranchesSet.add('other-session:branch-X'); + // Note: cleanupSession uses prefix `${sessionId}: ` (with trailing space) for matching + triggeredBranchesSet.add(`${sessionId}: branch-A`); + triggeredBranchesSet.add(`${sessionId}: branch-B`); + triggeredBranchesSet.add('other-session: branch-X'); // Fire session deleted event sessionEventEmitter.fire({ type: 'deleted', sessionId }); @@ -604,11 +453,11 @@ describe('ChatSessionSummarizationServiceImpl', () => { expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; // Verify triggeredBranches entries for deleted session are removed - expect(triggeredBranchesSet.has(`${sessionId}:branch-A`)).to.be.false; - expect(triggeredBranchesSet.has(`${sessionId}:branch-B`)).to.be.false; + expect(triggeredBranchesSet.has(`${sessionId}: branch-A`)).to.be.false; + expect(triggeredBranchesSet.has(`${sessionId}: branch-B`)).to.be.false; // Verify other session's triggeredBranches entries are preserved - expect(triggeredBranchesSet.has('other-session:branch-X')).to.be.true; + expect(triggeredBranchesSet.has('other-session: branch-X')).to.be.true; }); it('should populate branchTokens on persistence restore', () => { @@ -697,7 +546,7 @@ describe('ChatSessionSummarizationServiceImpl', () => { expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); }); - it('should handle cached input tokens correctly', () => { + it('should not double-count cached input tokens (inputTokens already includes cached)', () => { const sessionId = 'session-8'; const branchId = 'branch-A'; const requestId = `request-for-${branchId}`; @@ -706,18 +555,21 @@ describe('ChatSessionSummarizationServiceImpl', () => { sessionRegistry.set(sessionId, session); // Fire token usage event with cached tokens + // Per Anthropic API: inputTokens already INCLUDES cached tokens + // cachedInputTokens and readCachedInputTokens are just subsets indicating WHERE tokens came from tokenUsageEmitter.fire(createTokenUsage({ sessionId, requestId, - inputTokens: 1000, - cachedInputTokens: 500, - readCachedInputTokens: 200, + inputTokens: 1000, // This already includes any cached tokens + cachedInputTokens: 500, // Subset: 500 of the 1000 were cache writes + readCachedInputTokens: 200, // Subset: 200 of the 1000 were cache reads outputTokens: 100 })); - // Verify branchTokens includes all input token types + // Verify branchTokens uses only inputTokens (not sum with cached) + // totalInputTokens should be 1000, not 1000 + 500 + 200 = 1700 const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.equal(1700); // 1000 + 500 + 200 + expect(branchTokens[branchId]).to.equal(1100); // 1000 (input) + 100 (output), NOT 1800 }); it('should not update branchTokens when session is not found', () => { @@ -766,5 +618,21 @@ describe('ChatSessionSummarizationServiceImpl', () => { // Verify tokenTracker.resetSessionTokens was NOT called additionally expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); }); + + }); + + describe('cleanupSession', () => { + it('should clean up pendingSplits when session is deleted', () => { + const sessionId = 'session-to-cleanup'; + + // Add pending split + service.markPendingSplit(sessionId, 'request-1', [], new Map()); + + // Fire session deleted event + sessionEventEmitter.fire({ type: 'deleted', sessionId }); + + // Verify tokenTracker cleanup was called + expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; + }); }); }); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts index 8f1a808317468..9970ee122d950 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -17,14 +17,20 @@ import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; import { ILogger, nls } from '@theia/core'; import { FrontendApplicationContribution } from '@theia/core/lib/browser'; -import { TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; +import { ToolCall, ToolCallResult, TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; import { + ChatAgent, ChatService, ChatSession, ErrorChatResponseContent, ErrorChatResponseContentImpl, MutableChatModel, - MutableChatRequestModel + MutableChatRequestModel, + MutableChatResponseModel, + ParsedChatRequest, + SummaryChatResponseContentImpl, + ToolCallChatResponseContent, + ToolCallChatResponseContentImpl } from '../common'; import { isSessionCreatedEvent, isSessionDeletedEvent } from '../common/chat-service'; import { @@ -50,15 +56,30 @@ export interface ChatSessionSummarizationService { hasSummary(sessionId: string): boolean; /** - * Trigger summarization for a session. - * Called by the budget-aware tool loop when token threshold is exceeded mid-turn. - * - * @param sessionId The session to summarize - * @param skipReorder If true, skip removing/re-adding the trigger request (for mid-turn summarization) - * @returns Promise that resolves with the summary text on success, or `undefined` on failure + * Mark a pending mid-turn split. Called by the tool loop when budget exceeded. + * The split will be handled by checkAndHandleSummarization() after addContentsToResponse(). */ - triggerSummarization(sessionId: string, skipReorder: boolean): Promise; + markPendingSplit( + sessionId: string, + requestId: string, + pendingToolCalls: ToolCall[], + toolResults: Map + ): void; + /** + * Check and handle summarization after response content is added. + * Handles both mid-turn splits (from markPendingSplit) and between-turn summarization. + * + * @param sessionId The session ID + * @param agent The chat agent to invoke for summary/continuation + * @param request The current request being processed + * @returns true if summarization was triggered (caller should skip onResponseComplete), false otherwise + */ + checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel + ): Promise; } @injectable() @@ -83,10 +104,28 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza /** * Tracks which branches have triggered summarization. * Key format: `${sessionId}:${branchId}` - * Used for deduplication (prevents multiple triggers for the same branch). + * + * Used for deduplication (prevents multiple triggers for the same branch during a single growth cycle). + * + * **Cleanup behavior:** + * - After successful summarization: The branch key is REMOVED to allow future re-triggering + * when tokens grow again past the threshold. + * - On session deletion: All entries with matching sessionId prefix are removed. + * - On branch change: New branches automatically get fresh tracking since their branchId + * differs from previously tracked branches. */ protected triggeredBranches: Set = new Set(); + /** + * Stores pending mid-turn split data, keyed by sessionId. + * Consumed by checkAndHandleSummarization() after addContentsToResponse(). + */ + protected pendingSplits = new Map; + }>(); + @postConstruct() protected init(): void { // Listen to token usage events and attribute to correct branch @@ -160,19 +199,18 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return; } - const totalInputTokens = usage.inputTokens + (usage.cachedInputTokens ?? 0) + (usage.readCachedInputTokens ?? 0); - this.tokenTracker.setBranchTokens(usage.sessionId, branch.id, totalInputTokens); + const totalInputTokens = usage.inputTokens; + const totalTokens = totalInputTokens + (usage.outputTokens ?? 0); + + // Update branch tokens (for branch switching) + if (totalTokens > 0) { + this.tokenTracker.setBranchTokens(usage.sessionId, branch.id, totalTokens); + } const activeBranchId = this.getActiveBranchId(session); - if (branch.id === activeBranchId) { - this.tokenTracker.resetSessionTokens(usage.sessionId, totalInputTokens); - // Check threshold for active branch only - const branchKey = `${usage.sessionId}:${branch.id}`; - if (totalInputTokens >= CHAT_TOKEN_THRESHOLD && !this.triggeredBranches.has(branchKey)) { - this.triggeredBranches.add(branchKey); - this.handleThresholdExceeded({ sessionId: usage.sessionId, inputTokens: totalInputTokens }); - } + if (branch.id === activeBranchId && totalTokens > 0) { + this.tokenTracker.resetSessionTokens(usage.sessionId, totalTokens); } } @@ -190,15 +228,155 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza }); } - async triggerSummarization(sessionId: string, skipReorder: boolean): Promise { + markPendingSplit( + sessionId: string, + requestId: string, + pendingToolCalls: ToolCall[], + toolResults: Map + ): void { + this.logger.info(`Marking pending split for session ${sessionId}, request ${requestId}, ${pendingToolCalls.length} tool calls`); + this.pendingSplits.set(sessionId, { requestId, pendingToolCalls, toolResults }); + } + + async checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel + ): Promise { + // Check for pending mid-turn split first + const pendingSplit = this.pendingSplits.get(sessionId); + if (pendingSplit) { + // Consume immediately to prevent re-entry + this.pendingSplits.delete(sessionId); + await this.handleMidTurnSplit(sessionId, agent, request, pendingSplit); + return true; + } + + // Between-turn check: skip if summary or continuation request + if (request.request.kind === 'summary' || request.request.kind === 'continuation') { + return false; + } + + // Check if threshold exceeded for between-turn summarization + const tokens = this.tokenTracker.getSessionInputTokens(sessionId); + if (tokens === undefined || tokens < CHAT_TOKEN_THRESHOLD) { + return false; + } + + // Between-turn summarization - trigger via existing performSummarization const session = this.chatService.getSession(sessionId); if (!session) { - this.logger.warn(`Session ${sessionId} not found for summarization`); - return undefined; + return false; + } + + // Complete current response first if not already + if (!request.response.isComplete) { + request.response.complete(); } - this.logger.info(`Mid-turn summarization triggered for session ${sessionId}`); - return this.performSummarization(sessionId, session.model as MutableChatModel, skipReorder); + // Use existing performSummarization for between-turn (it marks stale after summary) + await this.performSummarization(sessionId, session.model as MutableChatModel); + return true; + } + + protected async handleMidTurnSplit( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel, + pendingSplit: { requestId: string; pendingToolCalls: ToolCall[]; toolResults: Map } + ): Promise { + const session = this.chatService.getSession(sessionId); + if (!session) { + this.logger.warn(`Session ${sessionId} not found for mid-turn split`); + return; + } + const model = session.model as MutableChatModel; + + // Step 1: Remove pending tool calls from current response + this.removePendingToolCallsFromResponse(request.response, pendingSplit.pendingToolCalls); + + // Step 2: Complete current response + request.response.complete(); + + // Step 3: Create and invoke summary request (NO stale marking yet - summary needs full history) + // eslint-disable-next-line max-len + const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. Do not include conversational elements, questions, or offers to continue. Do not start with a heading - output only the summary content.'; + + const summaryParsedRequest: ParsedChatRequest = { + request: { text: summaryPrompt, kind: 'summary' }, + parts: [{ kind: 'text', text: summaryPrompt, promptText: summaryPrompt, range: { start: 0, endExclusive: summaryPrompt.length } }], + toolRequests: new Map(), + variables: [] + }; + const summaryRequest = model.addRequest(summaryParsedRequest, undefined, { variables: [] }); + + // Invoke agent for summary (will populate summaryRequest.response) + await agent.invoke(summaryRequest); + + // Get summary text from response + const summaryText = summaryRequest.response.response.asDisplayString()?.trim() || ''; + + // Replace response content with SummaryChatResponseContent for proper rendering + summaryRequest.response.response.clearContent(); + summaryRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryText)); + + // Step 4: Mark ALL requests stale AFTER summary is generated (summary needed full history) + const allRequestsAfterSummary = model.getRequests(); + for (const req of allRequestsAfterSummary) { + // Don't mark continuation request stale (it will be created next) + if (req.request.kind !== 'continuation') { + (req as MutableChatRequestModel).isStale = true; + } + } + + // Step 5: Create continuation request with tool call content in response + // Include the summary plus an instruction to continue + const continuationSuffix = 'The tool call above was executed. Please continue with your task ' + + 'based on the result. If you need to make more tool calls to complete the task, please do so. ' + + 'Once you have all the information needed, provide your final response.'; + const continuationInstruction = `${summaryText}\n\n${continuationSuffix}`; + + const continuationParsedRequest: ParsedChatRequest = { + request: { text: continuationInstruction, kind: 'continuation' }, + parts: [{ kind: 'text', text: continuationInstruction, promptText: continuationInstruction, range: { start: 0, endExclusive: continuationInstruction.length } }], + toolRequests: new Map(), + variables: [] + }; + const continuationRequest = model.addRequest(continuationParsedRequest, undefined, { variables: [] }); + + // Add tool call content to response for UI display + for (const toolCall of pendingSplit.pendingToolCalls) { + const result = pendingSplit.toolResults.get(toolCall.id!); + const toolContent = new ToolCallChatResponseContentImpl( + toolCall.id, + toolCall.function?.name, + toolCall.function?.arguments, + true, // finished + result + ); + continuationRequest.response.response.addContent(toolContent); + } + + // Step 6: Invoke agent for continuation (token tracking will update from LLM response) + await agent.invoke(continuationRequest); + } + + protected removePendingToolCallsFromResponse( + response: MutableChatResponseModel, + pendingToolCalls: ToolCall[] + ): void { + const pendingIds = new Set(pendingToolCalls.map(tc => tc.id).filter(Boolean)); + const content = response.response.content; + + const filteredContent = content.filter(c => { + if (ToolCallChatResponseContent.is(c) && c.id && pendingIds.has(c.id)) { + return false; + } + return true; + }); + + response.response.clearContent(); + response.response.addContents(filteredContent); } protected async handleThresholdExceeded(event: SessionTokenThresholdEvent): Promise { @@ -214,10 +392,49 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return; } - this.logger.info(`Token threshold exceeded for session ${sessionId}: ${inputTokens} tokens. Starting summarization...`); + this.logger.info(`Token threshold exceeded for session ${sessionId}: ${inputTokens} tokens.Starting summarization...`); await this.performSummarization(sessionId, session.model as MutableChatModel); } + /** + * Execute a callback with summarization lock for the session. + * Ensures lock is released even if callback throws. + */ + protected async withSummarizationLock( + sessionId: string, + callback: () => Promise + ): Promise { + if (this.summarizingSession.has(sessionId)) { + return undefined; + } + this.summarizingSession.add(sessionId); + try { + return await callback(); + } finally { + this.summarizingSession.delete(sessionId); + } + } + + /** + * Update token tracking after successful summarization. + */ + protected updateTokenTrackingAfterSummary( + sessionId: string, + outputTokens: number + ): void { + this.tokenTracker.resetSessionTokens(sessionId, outputTokens); + // Update branch tokens and allow future re-triggering + const session = this.chatService.getSession(sessionId); + if (session) { + const activeBranchId = this.getActiveBranchId(session); + if (activeBranchId) { + const branchKey = `${sessionId}:${activeBranchId} `; + this.tokenTracker.setBranchTokens(sessionId, activeBranchId, outputTokens); + this.triggeredBranches.delete(branchKey); + } + } + } + /** * Core summarization logic shared by both threshold-triggered and explicit mid-turn summarization. * @@ -226,100 +443,87 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza * @returns The summary text on success, or `undefined` on failure */ protected async performSummarization(sessionId: string, model: MutableChatModel, skipReorder?: boolean): Promise { - if (this.summarizingSession.has(sessionId)) { - return undefined; - } - - this.summarizingSession.add(sessionId); - - try { + return this.withSummarizationLock(sessionId, async () => { // Always use 'end' position - reordering breaks the hierarchy structure // because the summary is added as continuation of the trigger request const position = 'end'; // eslint-disable-next-line max-len - const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue.'; - - const summaryText = await model.insertSummary( - async () => { - const invocation = await this.chatService.sendRequest(sessionId, { - text: summaryPrompt, - kind: 'summary' - }); - if (!invocation) { - return undefined; - } - - const request = await invocation.requestCompleted; - - // Set up token listener to capture output tokens - let capturedOutputTokens: number | undefined; - const tokenListener = this.tokenUsageClient.onTokenUsageUpdated(usage => { - if (usage.sessionId === sessionId && usage.requestId === request.id) { - capturedOutputTokens = usage.outputTokens; - } - }); - - try { - const response = await invocation.responseCompleted; - - // Validate response - const summaryResponseText = response.response.asDisplayString()?.trim(); - if (response.isError || !summaryResponseText) { + const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. ' + + 'Do not include conversational elements, questions, or offers to continue. Do not start with a heading - output only the summary content.'; + + try { + const summaryText = await model.insertSummary( + async () => { + const invocation = await this.chatService.sendRequest(sessionId, { + text: summaryPrompt, + kind: 'summary' + }); + if (!invocation) { return undefined; } - // Store captured output tokens on request for later retrieval - if (capturedOutputTokens !== undefined) { - (request as MutableChatRequestModel).addData('capturedOutputTokens', capturedOutputTokens); + const request = await invocation.requestCompleted; + + // Set up token listener to capture output tokens + let capturedOutputTokens: number | undefined; + const tokenListener = this.tokenUsageClient.onTokenUsageUpdated(usage => { + if (usage.sessionId === sessionId && usage.requestId === request.id) { + capturedOutputTokens = usage.outputTokens; + } + }); + + try { + const response = await invocation.responseCompleted; + + // Validate response + const summaryResponseText = response.response.asDisplayString()?.trim(); + if (response.isError || !summaryResponseText) { + return undefined; + } + + // Replace agent's markdown content with SummaryChatResponseContent for proper rendering + const mutableRequest = request as MutableChatRequestModel; + mutableRequest.response.response.clearContent(); + mutableRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryResponseText)); + + // Store captured output tokens on request for later retrieval + if (capturedOutputTokens !== undefined) { + (request as MutableChatRequestModel).addData('capturedOutputTokens', capturedOutputTokens); + } + + return { + requestId: request.id, + summaryText: summaryResponseText + }; + } finally { + tokenListener.dispose(); } + }, + position + ); + + if (!summaryText) { + this.logger.warn(`Summarization failed for session ${sessionId}`); + this.notifyUserOfFailure(model); + return undefined; + } - return { - requestId: request.id, - summaryText: summaryResponseText - }; - } finally { - tokenListener.dispose(); - } - }, - position - ); - - if (!summaryText) { - this.logger.warn(`Summarization failed for session ${sessionId}`); - this.notifyUserOfFailure(model); - return undefined; - } - - this.logger.info(`Added summary node to session ${sessionId}`); + this.logger.info(`Added summary node to session ${sessionId} `); - // Find the summary request to get captured output tokens - const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); - const outputTokens = summaryRequest?.getDataByKey('capturedOutputTokens') ?? 0; + // Find the summary request to get captured output tokens + const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); + const outputTokens = summaryRequest?.getDataByKey('capturedOutputTokens') ?? 0; - // Reset token count to the summary's output tokens (the new context size) - this.tokenTracker.resetSessionTokens(sessionId, outputTokens); - this.logger.info(`Reset token count for session ${sessionId} to ${outputTokens} after summarization`); + this.updateTokenTrackingAfterSummary(sessionId, outputTokens); + this.logger.info(`Reset token count for session ${sessionId} to ${outputTokens} after summarization`); - // Update branch tokens and allow future re-triggering - const session = this.chatService.getSession(sessionId); - if (session) { - const activeBranchId = this.getActiveBranchId(session); - if (activeBranchId) { - const branchKey = `${sessionId}:${activeBranchId}`; - this.tokenTracker.setBranchTokens(sessionId, activeBranchId, outputTokens); - this.triggeredBranches.delete(branchKey); - } + return summaryText; + } catch (error) { + this.logger.error(`Failed to summarize session ${sessionId}: `, error); + this.notifyUserOfFailure(model); + return undefined; } - - return summaryText; - - } catch (error) { - this.logger.error(`Failed to summarize session ${sessionId}:`, error); - this.notifyUserOfFailure(model); - return undefined; - } finally { - this.summarizingSession.delete(sessionId); - } + }); } /** @@ -362,12 +566,12 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza */ protected cleanupSession(sessionId: string): void { this.tokenTracker.clearSessionBranchTokens(sessionId); - const prefix = `${sessionId}:`; + this.pendingSplits.delete(sessionId); + const prefix = `${sessionId}: `; for (const key of this.triggeredBranches.keys()) { if (key.startsWith(prefix)) { this.triggeredBranches.delete(key); } } } - } diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts index a08d4c7624251..02202c6a336cb 100644 --- a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts @@ -35,6 +35,30 @@ describe('ChatSessionTokenTrackerImpl', () => { }); }); + describe('getSessionOutputTokens', () => { + it('should return undefined for unknown session', () => { + expect(tracker.getSessionOutputTokens('unknown-session')).to.be.undefined; + }); + }); + + describe('getSessionTotalTokens', () => { + it('should return undefined for unknown session', () => { + expect(tracker.getSessionTotalTokens('unknown-session')).to.be.undefined; + }); + + it('should return input tokens when only input is set', () => { + const sessionId = 'session-1'; + tracker.updateSessionTokens(sessionId, 5000); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5000); + }); + + it('should return sum of input and output tokens', () => { + const sessionId = 'session-1'; + tracker.updateSessionTokens(sessionId, 5000, 100); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5100); + }); + }); + describe('resetSessionTokens', () => { it('should update token count and fire onSessionTokensUpdated', () => { const sessionId = 'session-1'; @@ -46,18 +70,22 @@ describe('ChatSessionTokenTrackerImpl', () => { tracker.resetSessionTokens(sessionId, 50000); expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; expect(updateEvents).to.have.length(1); expect(updateEvents[0].sessionId).to.equal(sessionId); expect(updateEvents[0].inputTokens).to.equal(50000); + expect(updateEvents[0].outputTokens).to.be.undefined; // Reset to new baseline (simulating post-summarization) const newTokenCount = 10000; tracker.resetSessionTokens(sessionId, newTokenCount); expect(tracker.getSessionInputTokens(sessionId)).to.equal(newTokenCount); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; expect(updateEvents).to.have.length(2); expect(updateEvents[1].sessionId).to.equal(sessionId); expect(updateEvents[1].inputTokens).to.equal(newTokenCount); + expect(updateEvents[1].outputTokens).to.be.undefined; }); it('should delete token count and emit undefined when called with undefined', () => { @@ -76,9 +104,90 @@ describe('ChatSessionTokenTrackerImpl', () => { tracker.resetSessionTokens(sessionId, undefined); expect(tracker.getSessionInputTokens(sessionId)).to.be.undefined; + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; expect(updateEvents).to.have.length(2); expect(updateEvents[1].sessionId).to.equal(sessionId); expect(updateEvents[1].inputTokens).to.be.undefined; + expect(updateEvents[1].outputTokens).to.be.undefined; + }); + + it('should clear output tokens when resetting', () => { + const sessionId = 'session-1'; + + // Set both input and output tokens via updateSessionTokens + tracker.updateSessionTokens(sessionId, 5000, 500); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(500); + + // Reset should clear output tokens + tracker.resetSessionTokens(sessionId, 3000); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(3000); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; + }); + }); + + describe('updateSessionTokens', () => { + it('should set input tokens and reset output to 0 when input provided', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + tracker.updateSessionTokens(sessionId, 5000); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + expect(updateEvents).to.have.length(1); + expect(updateEvents[0].inputTokens).to.equal(5000); + expect(updateEvents[0].outputTokens).to.equal(0); + }); + + it('should update output tokens progressively', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Initial request with input tokens + tracker.updateSessionTokens(sessionId, 5000, 0); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + + // Progressive updates during streaming + tracker.updateSessionTokens(sessionId, undefined, 100); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(100); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); // Input unchanged + + tracker.updateSessionTokens(sessionId, undefined, 250); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(250); + + expect(updateEvents).to.have.length(3); + expect(updateEvents[2].inputTokens).to.equal(5000); + expect(updateEvents[2].outputTokens).to.equal(250); + }); + + it('should reset output to 0 when new input tokens arrive (new request)', () => { + const sessionId = 'session-1'; + + // First request + tracker.updateSessionTokens(sessionId, 5000, 500); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5500); + + // New request starts - input tokens set, output resets + tracker.updateSessionTokens(sessionId, 5500); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5500); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5500); + }); + + it('should not update input tokens when input is 0', () => { + const sessionId = 'session-1'; + + tracker.updateSessionTokens(sessionId, 5000); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + + // Input of 0 should not reset + tracker.updateSessionTokens(sessionId, 0, 100); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(100); }); }); diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.ts index 018b2d987cedd..0c39b9ccc6e6c 100644 --- a/packages/ai-chat/src/browser/chat-session-token-tracker.ts +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.ts @@ -30,10 +30,21 @@ export interface SessionTokenThresholdEvent { } /** - * Hardcoded token budget and threshold for chat sessions. + * The maximum token budget for a chat session. + * This represents the approximate context window size that chat sessions target. */ export const CHAT_TOKEN_BUDGET = 200000; + +/** + * The percentage of the token budget at which summarization is triggered. + */ export const CHAT_TOKEN_THRESHOLD_PERCENT = 0.9; + +/** + * The token threshold at which summarization is triggered. + * When input tokens reach this value (90% of budget), the system will + * attempt to summarize the conversation to stay within context limits. + */ export const CHAT_TOKEN_THRESHOLD = CHAT_TOKEN_BUDGET * CHAT_TOKEN_THRESHOLD_PERCENT; @injectable() @@ -45,7 +56,13 @@ export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { * Map of sessionId -> latest inputTokens count. * Updated when token usage is reported for requests in that session. */ - protected sessionTokens = new Map(); + protected sessionInputTokens = new Map(); + + /** + * Map of sessionId -> latest outputTokens count. + * Updated progressively during streaming. + */ + protected sessionOutputTokens = new Map(); /** * Map of branch tokens. Key format: `${sessionId}:${branchId}` @@ -53,7 +70,20 @@ export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { protected branchTokens = new Map(); getSessionInputTokens(sessionId: string): number | undefined { - return this.sessionTokens.get(sessionId); + return this.sessionInputTokens.get(sessionId); + } + + getSessionOutputTokens(sessionId: string): number | undefined { + return this.sessionOutputTokens.get(sessionId); + } + + getSessionTotalTokens(sessionId: string): number | undefined { + const input = this.sessionInputTokens.get(sessionId); + const output = this.sessionOutputTokens.get(sessionId); + if (input === undefined && output === undefined) { + return undefined; + } + return (input ?? 0) + (output ?? 0); } /** @@ -63,15 +93,31 @@ export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { * * @param sessionId - The session ID to reset * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. - * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined }`. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined, outputTokens: undefined }`. */ resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void { if (newTokenCount === undefined) { - this.sessionTokens.delete(sessionId); + this.sessionInputTokens.delete(sessionId); } else { - this.sessionTokens.set(sessionId, newTokenCount); + this.sessionInputTokens.set(sessionId, newTokenCount); + } + this.sessionOutputTokens.delete(sessionId); + this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount, outputTokens: undefined }); + } + + updateSessionTokens(sessionId: string, inputTokens?: number, outputTokens?: number): void { + if (inputTokens !== undefined && inputTokens > 0) { + this.sessionInputTokens.set(sessionId, inputTokens); + this.sessionOutputTokens.set(sessionId, 0); + } + if (outputTokens !== undefined) { + this.sessionOutputTokens.set(sessionId, outputTokens); } - this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount }); + this.onSessionTokensUpdatedEmitter.fire({ + sessionId, + inputTokens: this.sessionInputTokens.get(sessionId), + outputTokens: this.sessionOutputTokens.get(sessionId) + }); } setBranchTokens(sessionId: string, branchId: string, tokens: number): void { diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index ed1afaab9fec2..24dd673b8d6b7 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -51,7 +51,7 @@ import { LanguageModelStreamResponsePart } from '@theia/ai-core/lib/common'; import { ContributionProvider, ILogger, isArray, nls } from '@theia/core'; -import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify'; +import { inject, injectable, named, optional, postConstruct } from '@theia/core/shared/inversify'; import { ChatAgentService } from './chat-agent-service'; import { ChatModel, @@ -97,6 +97,24 @@ export namespace SystemMessageDescription { } } +/** + * Symbol for optional injection of the ChatSummarizationService. + * This allows browser implementations to provide summarization support. + */ +export const ChatSessionSummarizationServiceSymbol = Symbol('ChatSessionSummarizationService'); + +/** + * Minimal interface for chat summarization service. + * Used by AbstractChatAgent to optionally trigger summarization. + */ +export interface ChatSummarizationService { + checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel + ): Promise; +} + export interface ChatSessionContext extends AIVariableContext { request?: ChatRequestModel; model: ChatModel; @@ -170,6 +188,9 @@ export abstract class AbstractChatAgent implements ChatAgent { @inject(DefaultResponseContentFactory) protected defaultContentFactory: DefaultResponseContentFactory; + @inject(ChatSessionSummarizationServiceSymbol) @optional() + protected summarizationService?: ChatSummarizationService; + readonly abstract id: string; readonly abstract name: string; readonly abstract languageModelRequirements: LanguageModelRequirement[]; @@ -232,7 +253,10 @@ export abstract class AbstractChatAgent implements ChatAgent { const languageModelResponse = await this.sendLlmRequest(request, messages, tools, languageModel); await this.addContentsToResponse(languageModelResponse, request); - await this.onResponseComplete(request); + const summarizationHandled = await this.checkSummarization(request); + if (!summarizationHandled) { + await this.onResponseComplete(request); + } } catch (e) { this.handleError(request, e); @@ -320,7 +344,7 @@ export abstract class AbstractChatAgent implements ChatAgent { })); messages.push(...imageMessages); - if (request.response.isComplete || includeResponseInProgress) { + if (request.response.isComplete || includeResponseInProgress || request.request.kind === 'continuation') { const responseMessages: LanguageModelMessage[] = request.response.response.content .filter(c => { // we do not send errors or informational content @@ -404,6 +428,25 @@ export abstract class AbstractChatAgent implements ChatAgent { return undefined; } + /** + * Hook called after addContentsToResponse() to check if summarization is needed. + * Returns true if summarization was triggered (response handling is complete). + * Returns false if no summarization needed (caller should call onResponseComplete). + * + * Uses the injected ChatSummarizationService if available (browser context). + * Returns false in non-browser contexts where the service is not injected. + */ + protected async checkSummarization(request: MutableChatRequestModel): Promise { + if (this.summarizationService) { + return this.summarizationService.checkAndHandleSummarization( + request.session.id, + this, + request + ); + } + return false; + } + /** * Invoked after the response by the LLM completed successfully. * diff --git a/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts index 59d150d9fb23a..9439b1c04a311 100644 --- a/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts +++ b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts @@ -162,4 +162,36 @@ describe('ChatRequestHierarchyBranchImpl', () => { expect(branch!.get().id).to.equal(request.id); }); }); + + describe('dispose()', () => { + it('should not throw when disposing an empty branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + // Remove the request to make branch empty + branch!.remove(request); + expect(branch!.items.length).to.equal(0); + + // dispose() should not throw on empty branch + expect(() => branch!.dispose()).to.not.throw(); + }); + + it('should dispose all items when branch has items', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Test request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.items.length).to.equal(1); + + // dispose() should not throw and should clear items + expect(() => branch!.dispose()).to.not.throw(); + expect(branch!.items.length).to.equal(0); + }); + }); }); diff --git a/packages/ai-chat/src/common/chat-model-serialization.ts b/packages/ai-chat/src/common/chat-model-serialization.ts index bbb899b374840..f700eb058d14e 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.ts @@ -42,7 +42,7 @@ export interface SerializableChatRequestData { text: string; agentId?: string; /** The type of request. Defaults to 'user' if not specified (for backward compatibility). */ - kind?: 'user' | 'summary'; + kind?: 'user' | 'summary' | 'continuation'; /** Indicates this request has been summarized and should be excluded from prompt construction */ isStale?: boolean; changeSet?: { diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 7d7fc055cc4ad..db510156a6134 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -243,7 +243,7 @@ export interface ChangeSetDecoration { readonly additionalInfoSuffixIcon?: string[]; } -export type ChatRequestKind = 'user' | 'summary'; +export type ChatRequestKind = 'user' | 'summary' | 'continuation'; export interface ChatRequest { readonly text: string; @@ -1498,7 +1498,8 @@ export class ChatRequestHierarchyBranchImpl i } dispose(): void { - if (Disposable.is(this.get())) { + // Dispose all items if they are disposable (check first item, not get() which throws on empty) + if (this.items.length > 0 && Disposable.is(this.items[0].element)) { this.items.forEach(({ element }) => Disposable.is(element) && element.dispose()); } this.items.length = 0; diff --git a/packages/ai-chat/src/common/chat-request-parser.spec.ts b/packages/ai-chat/src/common/chat-request-parser.spec.ts index bf1f76e7e5ae8..d2f9c3d5225f7 100644 --- a/packages/ai-chat/src/common/chat-request-parser.spec.ts +++ b/packages/ai-chat/src/common/chat-request-parser.spec.ts @@ -20,19 +20,24 @@ import { ChatRequestParserImpl } from './chat-request-parser'; import { ChatAgentLocation } from './chat-agents'; import { ChatContext, ChatRequest } from './chat-model'; import { expect } from 'chai'; -import { AIVariable, DefaultAIVariableService, ResolvedAIVariable, ToolInvocationRegistryImpl, ToolRequest } from '@theia/ai-core'; +import { AIVariable, AIVariableService, ResolvedAIVariable, ToolInvocationRegistry, ToolRequest } from '@theia/ai-core'; import { ILogger, Logger } from '@theia/core'; import { ParsedChatRequestTextPart, ParsedChatRequestVariablePart } from './parsed-chat-request'; describe('ChatRequestParserImpl', () => { const chatAgentService = sinon.createStubInstance(ChatAgentServiceImpl); - const variableService = sinon.createStubInstance(DefaultAIVariableService); - const toolInvocationRegistry = sinon.createStubInstance(ToolInvocationRegistryImpl); + const variableService = { + getVariable: sinon.stub(), + resolveVariable: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + const toolInvocationRegistry = { + getFunction: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; const logger: ILogger = sinon.createStubInstance(Logger); const parser = new ChatRequestParserImpl(chatAgentService, variableService, toolInvocationRegistry, logger); beforeEach(() => { - // Reset our stubs before each test + // Reset all stubs before each test to ensure test isolation sinon.reset(); }); diff --git a/packages/ai-chat/src/common/chat-service.ts b/packages/ai-chat/src/common/chat-service.ts index 8d100c35e8c88..8654b1836fbbf 100644 --- a/packages/ai-chat/src/common/chat-service.ts +++ b/packages/ai-chat/src/common/chat-service.ts @@ -276,7 +276,11 @@ export class ChatServiceImpl implements ChatService { return undefined; } - this.cancelIncompleteRequests(session); + // Don't cancel incomplete requests for summary requests - they run concurrently + // with the active request during mid-turn summarization + if (request.kind !== 'summary') { + this.cancelIncompleteRequests(session); + } const resolutionContext: ChatSessionContext = { model: session.model }; const resolvedContext = await this.resolveChatContext(request.variables ?? session.model.context.getVariables(), resolutionContext); @@ -317,7 +321,11 @@ export class ChatServiceImpl implements ChatService { } }); - agent.invoke(requestModel).catch(error => requestModel.response.error(error)); + // Don't invoke agent for continuation requests - they are containers for moved tool results + // The tool loop handles completing the response externally + if (request.kind !== 'continuation') { + agent.invoke(requestModel).catch(error => requestModel.response.error(error)); + } return invocation; } diff --git a/packages/ai-chat/src/common/chat-session-token-tracker.ts b/packages/ai-chat/src/common/chat-session-token-tracker.ts index 94b09b27f6027..c67558d46da0f 100644 --- a/packages/ai-chat/src/common/chat-session-token-tracker.ts +++ b/packages/ai-chat/src/common/chat-session-token-tracker.ts @@ -27,6 +27,12 @@ export interface SessionTokenUpdateEvent { * - `undefined`: Unknown/not yet measured (branch has never had an LLM request; do NOT coerce to 0) */ inputTokens: number | undefined; + /** + * The output token count for the active branch. + * - `number`: Known token count from the most recent LLM response (updated progressively during streaming) + * - `undefined`: Unknown/not yet measured + */ + outputTokens: number | undefined; } export const ChatSessionTokenTracker = Symbol('ChatSessionTokenTracker'); @@ -50,16 +56,39 @@ export interface ChatSessionTokenTracker { */ getSessionInputTokens(sessionId: string): number | undefined; + /** + * Get the latest output token count for a session. + * Returns the outputTokens from the most recent request in the session. + */ + getSessionOutputTokens(sessionId: string): number | undefined; + + /** + * Get the total token count for a session (input + output). + * Returns the sum of input and output tokens, representing current context window usage. + * Returns undefined if neither input nor output tokens are known. + */ + getSessionTotalTokens(sessionId: string): number | undefined; + /** * Reset the session's token count to a new baseline. * Called after summarization to reflect the reduced token usage. * * @param sessionId - The session ID to reset * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. - * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined }`. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined, outputTokens: undefined }`. */ resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void; + /** + * Update session tokens with input and/or output token counts. + * Called when token usage information is received from the LLM. + * + * @param sessionId - The session ID to update + * @param inputTokens - The input token count (sets new baseline, resets output to 0) + * @param outputTokens - The output token count (updated progressively during streaming) + */ + updateSessionTokens(sessionId: string, inputTokens?: number, outputTokens?: number): void; + /** * Store token count for a specific branch. * @param sessionId - The session ID diff --git a/packages/ai-core/src/browser/frontend-language-model-service.ts b/packages/ai-core/src/browser/frontend-language-model-service.ts index 087ed8aafea9a..7b6bba2a13a77 100644 --- a/packages/ai-core/src/browser/frontend-language-model-service.ts +++ b/packages/ai-core/src/browser/frontend-language-model-service.ts @@ -31,30 +31,12 @@ export class FrontendLanguageModelServiceImpl extends LanguageModelServiceImpl { languageModel: LanguageModel, languageModelRequest: UserRequest ): Promise { - const requestSettings = this.preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); - - const ids = languageModel.id.split('/'); - const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], languageModelRequest.agentId); - if (matchingSetting?.requestSettings) { - // Merge the settings, with user request taking precedence - languageModelRequest.settings = { - ...matchingSetting.requestSettings, - ...languageModelRequest.settings - }; - } - if (matchingSetting?.clientSettings) { - // Merge the clientSettings, with user request taking precedence - languageModelRequest.clientSettings = { - ...matchingSetting.clientSettings, - ...languageModelRequest.clientSettings - }; - } - + applyRequestSettings(languageModelRequest, languageModel.id, languageModelRequest.agentId, this.preferenceService); return super.sendRequest(languageModel, languageModelRequest); } } -export const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: string, providerId: string, agentId?: string): RequestSetting => { +const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: string, providerId: string, agentId?: string): RequestSetting => { const prioritizedSettings = Prioritizeable.prioritizeAllSync(requestSettings, setting => getRequestSettingSpecificity(setting, { modelId, @@ -65,3 +47,31 @@ export const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: const matchingSetting = prioritizedSettings.reduceRight((acc, cur) => ({ ...acc, ...cur.value }), {} as RequestSetting); return matchingSetting; }; + +/** + * Apply request settings from preferences to a user request. + * Merges settings based on model ID, provider ID, and agent ID specificity. + */ +export const applyRequestSettings = ( + request: UserRequest, + languageModelId: string, + agentId: string | undefined, + preferenceService: PreferenceService +): void => { + const requestSettings = preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); + const ids = languageModelId.split('/'); + const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], agentId); + + if (matchingSetting?.requestSettings) { + request.settings = { + ...matchingSetting.requestSettings, + ...request.settings + }; + } + if (matchingSetting?.clientSettings) { + request.clientSettings = { + ...matchingSetting.clientSettings, + ...request.clientSettings + }; + } +}; diff --git a/packages/ai-ollama/src/node/ollama-language-model.ts b/packages/ai-ollama/src/node/ollama-language-model.ts index 91ebd23631f51..7ea4d4013ad64 100644 --- a/packages/ai-ollama/src/node/ollama-language-model.ts +++ b/packages/ai-ollama/src/node/ollama-language-model.ts @@ -17,7 +17,6 @@ import { LanguageModel, LanguageModelParsedResponse, - LanguageModelRequest, LanguageModelMessage, LanguageModelResponse, LanguageModelStreamResponse, @@ -27,7 +26,8 @@ import { ToolRequestParametersProperties, ImageContent, TokenUsageService, - LanguageModelStatus + LanguageModelStatus, + UserRequest } from '@theia/ai-core'; import { CancellationToken } from '@theia/core'; import { ChatRequest, Message, Ollama, Options, Tool, ToolCall as OllamaToolCall, ChatResponse } from 'ollama'; @@ -58,7 +58,7 @@ export class OllamaModel implements LanguageModel { protected readonly tokenUsageService?: TokenUsageService ) { } - async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise { + async request(request: UserRequest, cancellationToken?: CancellationToken): Promise { const settings = this.getSettings(request); const ollama = this.initializeOllama(); const stream = !(request.settings?.stream === false); // true by default, false only if explicitly specified @@ -71,7 +71,7 @@ export class OllamaModel implements LanguageModel { stream }; const structured = request.response_format?.type === 'json_schema'; - const sessionId = 'sessionId' in request ? (request as { sessionId?: string }).sessionId : undefined; + const sessionId = request.sessionId; return this.dispatchRequest(ollama, ollamaRequest, structured, sessionId, cancellationToken); } @@ -80,7 +80,7 @@ export class OllamaModel implements LanguageModel { * @param request The language model request containing specific settings. * @returns A partial ChatRequest object containing the merged settings. */ - protected getSettings(request: LanguageModelRequest): Partial { + protected getSettings(request: UserRequest): Partial { const settings = request.settings ?? {}; return { options: settings as Partial diff --git a/packages/ai-openai/src/node/openai-response-api-utils.ts b/packages/ai-openai/src/node/openai-response-api-utils.ts index e16ca7e778691..7ca081fb92bc7 100644 --- a/packages/ai-openai/src/node/openai-response-api-utils.ts +++ b/packages/ai-openai/src/node/openai-response-api-utils.ts @@ -322,8 +322,6 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator(); - protected totalInputTokens = 0; - protected totalOutputTokens = 0; protected iteration = 0; protected readonly maxIterations: number; protected readonly tools: FunctionTool[] | undefined; @@ -448,10 +446,17 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator { this.done = true; - // Record final token usage - if (this.tokenUsageService && (this.totalInputTokens > 0 || this.totalOutputTokens > 0)) { - try { - await this.tokenUsageService.recordTokenUsage( - this.modelId, - { - inputTokens: this.totalInputTokens, - outputTokens: this.totalOutputTokens, - requestId: this.request.requestId, - sessionId: this.request.sessionId - } - ); - } catch (error) { - console.error('Error recording token usage:', error); - } - } - // Resolve any outstanding requests if (this.terminalError) { this.requestQueue.forEach(request => request.reject(this.terminalError)); From 30ef4e8b52358cf0f38e8288959b21f944075fa3 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Wed, 21 Jan 2026 13:20:01 +0100 Subject: [PATCH 4/5] small fixes --- .prompts/project-info.prompttemplate | 2 +- .../chat-tree-view/chat-view-tree-widget.tsx | 2 +- ...chat-session-summarization-service.spec.ts | 25 +++++++++++-------- .../chat-session-summarization-service.ts | 7 +++--- 4 files changed, 20 insertions(+), 16 deletions(-) diff --git a/.prompts/project-info.prompttemplate b/.prompts/project-info.prompttemplate index 5af299f7ddbe5..c6766ff6fb8e6 100644 --- a/.prompts/project-info.prompttemplate +++ b/.prompts/project-info.prompttemplate @@ -65,7 +65,7 @@ The main example applications are in `/examples/`: | Command (from root) | Purpose | |---------------------|---------| -| `npm install` | Install dependencies (required first) | +| `npm ci` | Install dependencies (required first) | | `npm run build:browser` | Build all packages + browser app | | `npm run start:browser` | Start browser example at localhost:3000 | | `npm run start:electron` | Start Electron desktop app | diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index eff865af63d61..31012188897d1 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -394,7 +394,7 @@ export class ChatViewTreeWidget extends TreeWidget { return { parent: this.model.root as CompositeTreeNode, get id(): string { - return this.request?.id ?? `empty-branch-${branch.id}`; + return this.request.id; }, get request(): ChatRequestModel { return branch.get(); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts index c3a2806150d72..f6f1c9a0d1163 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts @@ -428,8 +428,8 @@ describe('ChatSessionSummarizationServiceImpl', () => { outputTokens: 100 })); - // resetSessionTokens SHOULD be called for active branch with totalTokens (inputTokens + outputTokens) - expect(tokenTracker.resetSessionTokens.calledWith(sessionId, CHAT_TOKEN_THRESHOLD + 10100)).to.be.true; // threshold + 10000 input + 100 output + // updateSessionTokens SHOULD be called for active branch with separate input/output values + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, CHAT_TOKEN_THRESHOLD + 10000, 100)).to.be.true; }); it('should remove all branch entries when session is deleted', () => { @@ -546,7 +546,7 @@ describe('ChatSessionSummarizationServiceImpl', () => { expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); }); - it('should not double-count cached input tokens (inputTokens already includes cached)', () => { + it('should include readCachedInputTokens in total context size', () => { const sessionId = 'session-8'; const branchId = 'branch-A'; const requestId = `request-for-${branchId}`; @@ -555,21 +555,24 @@ describe('ChatSessionSummarizationServiceImpl', () => { sessionRegistry.set(sessionId, session); // Fire token usage event with cached tokens - // Per Anthropic API: inputTokens already INCLUDES cached tokens - // cachedInputTokens and readCachedInputTokens are just subsets indicating WHERE tokens came from + // For providers like Anthropic with caching: + // - inputTokens: raw/non-cached input tokens + // - readCachedInputTokens: cached tokens read from cache + // Total context size = inputTokens + readCachedInputTokens tokenUsageEmitter.fire(createTokenUsage({ sessionId, requestId, - inputTokens: 1000, // This already includes any cached tokens - cachedInputTokens: 500, // Subset: 500 of the 1000 were cache writes - readCachedInputTokens: 200, // Subset: 200 of the 1000 were cache reads + inputTokens: 1000, // Raw input tokens (non-cached) + cachedInputTokens: 500, // Tokens written to cache (informational only) + readCachedInputTokens: 200, // Cached tokens read - adds to context size outputTokens: 100 })); - // Verify branchTokens uses only inputTokens (not sum with cached) - // totalInputTokens should be 1000, not 1000 + 500 + 200 = 1700 + // Verify branchTokens includes readCachedInputTokens for total context size + // totalInputTokens = 1000 (input) + 200 (cached read) = 1200 + // totalTokens = 1200 + 100 (output) = 1300 const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.equal(1100); // 1000 (input) + 100 (output), NOT 1800 + expect(branchTokens[branchId]).to.equal(1300); // 1000 + 200 + 100 }); it('should not update branchTokens when session is not found', () => { diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts index 9970ee122d950..3b5b2bcd64a99 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -199,7 +199,8 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return; } - const totalInputTokens = usage.inputTokens; + // Total input = raw input + cached tokens read (for providers like Anthropic with caching) + const totalInputTokens = usage.inputTokens + (usage.readCachedInputTokens ?? 0); const totalTokens = totalInputTokens + (usage.outputTokens ?? 0); // Update branch tokens (for branch switching) @@ -209,8 +210,8 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza const activeBranchId = this.getActiveBranchId(session); - if (branch.id === activeBranchId && totalTokens > 0) { - this.tokenTracker.resetSessionTokens(usage.sessionId, totalTokens); + if (branch.id === activeBranchId) { + this.tokenTracker.updateSessionTokens(usage.sessionId, totalInputTokens, usage.outputTokens); } } From e589233f9589a5bb1864a78f540990d985665788 Mon Sep 17 00:00:00 2001 From: Eugen Neufeld Date: Thu, 22 Jan 2026 14:57:15 +0100 Subject: [PATCH 5/5] more fixes --- .../src/node/anthropic-language-model.ts | 50 +- .../chat-tree-view/chat-view-tree-widget.tsx | 10 - .../chat-language-model-service.spec.ts | 90 ++++ .../browser/chat-language-model-service.ts | 14 +- ...chat-session-summarization-service.spec.ts | 489 ++++++++---------- .../chat-session-summarization-service.ts | 207 +++----- packages/ai-chat/src/common/chat-agents.ts | 74 ++- .../common/chat-model-insert-summary.spec.ts | 17 +- packages/ai-chat/src/common/chat-model.ts | 8 +- .../src/node/google-language-model.ts | 29 +- .../src/browser/ai-history-exchange-card.tsx | 16 + .../src/common/orchestrator-chat-agent.ts | 16 +- .../src/node/ollama-language-model.ts | 24 +- .../src/node/openai-language-model.ts | 26 - .../src/node/openai-response-api-utils.ts | 57 +- .../src/node/openai-streaming-iterator.ts | 20 +- .../src/node/vercel-ai-language-model.ts | 34 -- 17 files changed, 527 insertions(+), 654 deletions(-) diff --git a/packages/ai-anthropic/src/node/anthropic-language-model.ts b/packages/ai-anthropic/src/node/anthropic-language-model.ts index 2ad01beba420b..9d758206d5021 100644 --- a/packages/ai-anthropic/src/node/anthropic-language-model.ts +++ b/packages/ai-anthropic/src/node/anthropic-language-model.ts @@ -23,7 +23,6 @@ import { LanguageModelStreamResponsePart, LanguageModelTextResponse, TokenUsageService, - TokenUsageParams, UserRequest, ImageContent, ToolCallResult, @@ -312,36 +311,22 @@ export class AnthropicModel implements LanguageModel { } else if (event.type === 'message_start') { currentMessages.push(event.message); currentMessage = event.message; - // Report input tokens immediately - if (that.tokenUsageService && event.message.usage) { - that.tokenUsageService.recordTokenUsage(that.id, { - inputTokens: event.message.usage.input_tokens, - outputTokens: event.message.usage.output_tokens, - cachedInputTokens: event.message.usage.cache_creation_input_tokens ?? undefined, - readCachedInputTokens: event.message.usage.cache_read_input_tokens ?? undefined, - requestId: request.requestId, - sessionId: request.sessionId - }); + // Yield initial usage data (input tokens known, output tokens = 0) + if (event.message.usage) { + yield { + input_tokens: event.message.usage.input_tokens, + output_tokens: 0, + cache_creation_input_tokens: event.message.usage.cache_creation_input_tokens ?? undefined, + cache_read_input_tokens: event.message.usage.cache_read_input_tokens ?? undefined + }; } } else if (event.type === 'message_stop') { if (currentMessage) { + // Yield final output tokens only (input/cached tokens already yielded at message_start) yield { - input_tokens: currentMessage.usage.input_tokens, - output_tokens: currentMessage.usage.output_tokens, - cache_creation_input_tokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, - cache_read_input_tokens: currentMessage.usage.cache_read_input_tokens ?? undefined + input_tokens: 0, + output_tokens: currentMessage.usage.output_tokens }; - // Report final token usage - if (that.tokenUsageService) { - that.tokenUsageService.recordTokenUsage(that.id, { - inputTokens: currentMessage.usage.input_tokens, - outputTokens: currentMessage.usage.output_tokens, - cachedInputTokens: currentMessage.usage.cache_creation_input_tokens ?? undefined, - readCachedInputTokens: currentMessage.usage.cache_read_input_tokens ?? undefined, - requestId: request.requestId, - sessionId: request.sessionId - }); - } } } } @@ -438,19 +423,6 @@ export class AnthropicModel implements LanguageModel { const response = await anthropic.messages.create(params); const textContent = response.content[0]; - // Record token usage if token usage service is available - if (this.tokenUsageService && response.usage) { - const tokenUsageParams: TokenUsageParams = { - inputTokens: response.usage.input_tokens, - outputTokens: response.usage.output_tokens, - cachedInputTokens: response.usage.cache_creation_input_tokens ?? undefined, - readCachedInputTokens: response.usage.cache_read_input_tokens ?? undefined, - requestId: request.requestId, - sessionId: request.sessionId - }; - await this.tokenUsageService.recordTokenUsage(this.id, tokenUsageParams); - } - if (textContent?.type === 'text') { return { text: textContent.text }; } diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index 31012188897d1..2b47ba76dd616 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -485,10 +485,6 @@ export class ChatViewTreeWidget extends TreeWidget { const nodes: TreeNode[] = []; this.chatModelId = chatModel.id; chatModel.getBranches().forEach(branch => { - // Skip empty branches (can occur during insertSummary operations) - if (branch.items.length === 0) { - return; - } const request = branch.get(); nodes.push(this.mapRequestToNode(branch)); // Skip separate response node for summary/continuation requests - response is rendered within request node @@ -508,12 +504,6 @@ export class ChatViewTreeWidget extends TreeWidget { if (!TreeNode.isVisible(node)) { return undefined; } - if (isRequestNode(node)) { - // Skip rendering if the branch is empty (request will be undefined) - if (!node.request) { - return undefined; - } - } if (!(isRequestNode(node) || isResponseNode(node))) { return super.renderNode(node, props); } diff --git a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts index 4e6a864a24bfa..392980d383971 100644 --- a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts @@ -365,6 +365,96 @@ describe('ChatLanguageModelServiceImpl', () => { }); }); + describe('subRequestId handling', () => { + it('should set subRequestId with format requestId-0 on first call', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.subRequestId).to.equal('request-1-0'); + }); + + it('should increment subRequestId across multiple tool loop iterations', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // First call: model returns tool call without result (respected singleRoundTrip) + const firstStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Second call: model returns another tool call + const secondStream = createMockStream([ + { content: 'Using another tool' }, + { tool_calls: [{ id: 'call-2', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Third call: model returns final response + const thirdStream = createMockStream([ + { content: 'Done!' } + ]); + + mockLanguageModel.request + .onFirstCall().resolves({ stream: firstStream }) + .onSecondCall().resolves({ stream: secondStream }) + .onThirdCall().resolves({ stream: thirdStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume the stream to trigger the tool loop + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify three LLM calls were made + expect(mockLanguageModel.request.calledThrice).to.be.true; + + // Verify subRequestId increments: request-1-0, request-1-1, request-1-2 + const firstCallRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(firstCallRequest.subRequestId).to.equal('request-1-0'); + + const secondCallRequest = mockLanguageModel.request.secondCall.args[0] as UserRequest; + expect(secondCallRequest.subRequestId).to.equal('request-1-1'); + + const thirdCallRequest = mockLanguageModel.request.thirdCall.args[0] as UserRequest; + expect(thirdCallRequest.subRequestId).to.equal('request-1-2'); + }); + }); + describe('error handling', () => { it('should throw error when model returns non-streaming response', async () => { mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); diff --git a/packages/ai-chat/src/browser/chat-language-model-service.ts b/packages/ai-chat/src/browser/chat-language-model-service.ts index 50e631f9ba3ca..277408e186613 100644 --- a/packages/ai-chat/src/browser/chat-language-model-service.ts +++ b/packages/ai-chat/src/browser/chat-language-model-service.ts @@ -116,13 +116,14 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { let continueLoop = true; + let iteration = 0; while (continueLoop) { continueLoop = false; // Get response from model const response = await that.sendSingleRoundTripRequest( - languageModel, request, currentMessages, sessionId + languageModel, request, currentMessages, sessionId, iteration ); // Process the stream and collect tool calls @@ -149,7 +150,6 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { if (shouldSplit && sessionId) { // Budget exceeded - mark pending split and exit cleanly - that.logger.info(`Splitting turn for session ${sessionId} due to budget exceeded`); that.summarizationService.markPendingSplit(sessionId, request.requestId, pendingToolCalls, toolResults); return; } @@ -170,6 +170,7 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { })); yield { tool_calls: resultsToYield }; + iteration++; continueLoop = true; } } @@ -187,12 +188,14 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { languageModel: LanguageModel, request: UserRequest, currentMessages: LanguageModelMessage[], - sessionId: string | undefined + sessionId: string | undefined, + iteration: number ): Promise { const currentRequest: UserRequest = { ...request, messages: currentMessages, - singleRoundTrip: true + singleRoundTrip: true, + subRequestId: `${request.requestId}-${iteration}` }; let response: LanguageModelResponse; @@ -263,9 +266,6 @@ export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { const toolResults = await this.executeTools(pendingToolCalls, tools); const shouldSplit = sessionId !== undefined && this.isBudgetExceeded(sessionId); - if (shouldSplit) { - this.logger.info(`Budget exceeded after tool execution for session ${sessionId}, will trigger split...`); - } return { toolResults, shouldSplit }; } diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts index f6f1c9a0d1163..1eb1fb036bce0 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts @@ -18,9 +18,9 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; import { Container } from '@theia/core/shared/inversify'; import { Emitter, ILogger } from '@theia/core'; -import { TokenUsage, TokenUsageServiceClient, ToolCall, ToolCallResult } from '@theia/ai-core'; +import { ToolCall, ToolCallResult, UsageResponsePart } from '@theia/ai-core'; import { ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; -import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; import { ChatService, SessionCreatedEvent, SessionDeletedEvent } from '../common/chat-service'; import { ChatSession } from '../common'; import { ChatSessionStore } from '../common/chat-session-store'; @@ -30,27 +30,24 @@ describe('ChatSessionSummarizationServiceImpl', () => { let service: ChatSessionSummarizationServiceImpl; let tokenTracker: sinon.SinonStubbedInstance; let chatService: sinon.SinonStubbedInstance; - let tokenUsageClient: sinon.SinonStubbedInstance; let logger: sinon.SinonStubbedInstance; - let tokenUsageEmitter: Emitter; let sessionEventEmitter: Emitter; let sessionRegistry: Map; let sessionStore: sinon.SinonStubbedInstance; - // Helper to create a mock TokenUsage event - function createTokenUsage(params: { - sessionId: string; - requestId: string; - inputTokens: number; - outputTokens: number; - cachedInputTokens?: number; - readCachedInputTokens?: number; - }): TokenUsage { + // Helper to create a mock UsageResponsePart + function createUsageResponsePart(params: { + input_tokens: number; + output_tokens: number; + cache_creation_input_tokens?: number; + cache_read_input_tokens?: number; + }): UsageResponsePart { return { - ...params, - model: 'test-model', - timestamp: new Date() + input_tokens: params.input_tokens, + output_tokens: params.output_tokens, + cache_creation_input_tokens: params.cache_creation_input_tokens, + cache_read_input_tokens: params.cache_read_input_tokens }; } @@ -85,8 +82,7 @@ describe('ChatSessionSummarizationServiceImpl', () => { beforeEach(() => { container = new Container(); - // Create emitters for event simulation - tokenUsageEmitter = new Emitter(); + // Create emitter for session event simulation sessionEventEmitter = new Emitter(); // Create session registry for dynamic lookup @@ -138,10 +134,6 @@ describe('ChatSessionSummarizationServiceImpl', () => { sendRequest: sinon.stub() } as unknown as sinon.SinonStubbedInstance; - tokenUsageClient = { - onTokenUsageUpdated: tokenUsageEmitter.event - } as unknown as sinon.SinonStubbedInstance; - logger = { info: sinon.stub(), warn: sinon.stub(), @@ -160,7 +152,6 @@ describe('ChatSessionSummarizationServiceImpl', () => { // Bind to container container.bind(ChatSessionTokenTracker).toConstantValue(tokenTracker); container.bind(ChatService).toConstantValue(chatService); - container.bind(TokenUsageServiceClient).toConstantValue(tokenUsageClient); container.bind(ILogger).toConstantValue(logger); container.bind(ChatSessionStore).toConstantValue(sessionStore); container.bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); @@ -172,13 +163,12 @@ describe('ChatSessionSummarizationServiceImpl', () => { afterEach(() => { sinon.restore(); - tokenUsageEmitter.dispose(); sessionEventEmitter.dispose(); sessionRegistry.clear(); }); describe('markPendingSplit', () => { - it('should store pending split data', () => { + it('should store pending split data', async () => { const sessionId = 'session-1'; const requestId = 'request-1'; const pendingToolCalls: ToolCall[] = [ @@ -186,10 +176,53 @@ describe('ChatSessionSummarizationServiceImpl', () => { ]; const toolResults = new Map([['tool-1', 'result']]); + // Create mock session for handleMidTurnSplit + const session = createMockSession(sessionId, 'branch-1'); + sessionRegistry.set(sessionId, session); + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { addRequest: sinon.SinonStub }).addRequest = sinon.stub().returns({ + id: 'new-request', + response: { + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub(), + asDisplayString: sinon.stub().returns('Summary text') + } + } + }); + (modelStub as unknown as { getRequests: sinon.SinonStub }).getRequests = sinon.stub().returns([]); + service.markPendingSplit(sessionId, requestId, pendingToolCalls, toolResults); - // Verify info was logged - expect((logger.info as sinon.SinonStub).calledWithMatch('Marking pending split')).to.be.true; + // Verify pending split is stored by calling checkAndHandleSummarization + // which consumes the pending split and returns true + const mockAgent = { invoke: sinon.stub().resolves() }; + const mockResponse = { + isComplete: false, + complete: sinon.stub(), + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub() + } + }; + const mockRequest = { + id: requestId, + request: { kind: 'user' }, + response: mockResponse + }; + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + undefined + ); + + // Should return true because pending split was consumed + expect(result).to.be.true; }); }); @@ -202,17 +235,19 @@ describe('ChatSessionSummarizationServiceImpl', () => { request: { kind: 'summary' }, response: { isComplete: false } }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); const result = await service.checkAndHandleSummarization( sessionId, mockAgent as unknown as import('../common').ChatAgent, - mockRequest as unknown as import('../common').MutableChatRequestModel + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage ); expect(result).to.be.false; }); - it('should return false when request kind is continuation', async () => { + it('should return false when request kind is continuation and below threshold', async () => { const sessionId = 'session-1'; const mockAgent = { invoke: sinon.stub() }; const mockRequest = { @@ -220,111 +255,168 @@ describe('ChatSessionSummarizationServiceImpl', () => { request: { kind: 'continuation' }, response: { isComplete: false } }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); // Below threshold const result = await service.checkAndHandleSummarization( sessionId, mockAgent as unknown as import('../common').ChatAgent, - mockRequest as unknown as import('../common').MutableChatRequestModel + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage ); expect(result).to.be.false; + // Verify token tracker was still updated for continuation requests + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 100, 50)).to.be.true; }); - it('should return false when tokens are below threshold', async () => { + it('should not skip continuation request when it exceeds threshold', async () => { const sessionId = 'session-1'; - tokenTracker.getSessionInputTokens.returns(100); // Below threshold + const session = createMockSession(sessionId, 'branch-1'); + sessionRegistry.set(sessionId, session); const mockAgent = { invoke: sinon.stub() }; + const completeStub = sinon.stub(); const mockRequest = { id: 'request-1', - request: { kind: 'user' }, - response: { isComplete: false } + request: { kind: 'continuation' }, + response: { + isComplete: false, + complete: completeStub + } }; + // 7000 tokens > CHAT_TOKEN_THRESHOLD (6300) + const usage = createUsageResponsePart({ input_tokens: 7000, output_tokens: 500 }); - const result = await service.checkAndHandleSummarization( + // Mock model.insertSummary for performSummarization + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { insertSummary: sinon.SinonStub }).insertSummary = sinon.stub().resolves('Summary text'); + + // Call the method - it may or may not fully complete summarization + // depending on mocks, but the key behavior is that it doesn't skip + // the threshold check for continuation requests when above threshold + await service.checkAndHandleSummarization( sessionId, mockAgent as unknown as import('../common').ChatAgent, - mockRequest as unknown as import('../common').MutableChatRequestModel + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage ); - expect(result).to.be.false; + // Verify token tracker was updated with high token values + // This confirms the method processed the usage data and didn't skip early + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 7000, 500)).to.be.true; }); - }); - describe('per-branch token tracking', () => { - it('should attribute tokens to the correct branch via model.getBranch(requestId)', () => { + it('should return false when tokens are below threshold', async () => { const sessionId = 'session-1'; - const branchId = 'branch-A'; - const requestId = `request-for-${branchId}`; - const session = createMockSession(sessionId, branchId); - sessionRegistry.set(sessionId, session); + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'user' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); // Below threshold - // Fire token usage event - tokenUsageEmitter.fire(createTokenUsage({ + const result = await service.checkAndHandleSummarization( sessionId, - requestId, - inputTokens: 1000, - outputTokens: 100 - })); + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); - // Verify branchTokens map is updated with totalTokens (inputTokens + outputTokens) - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.equal(1100); // 1000 input + 100 output + expect(result).to.be.false; }); - it('should update branchTokens when token usage event is for active branch', () => { - const sessionId = 'session-2'; - const activeBranchId = 'branch-active'; - const requestId = `request-for-${activeBranchId}`; - const session = createMockSession(sessionId, activeBranchId); - - sessionRegistry.set(sessionId, session); + it('should update token tracker with usage data', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'user' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ + input_tokens: 1000, + output_tokens: 200, + cache_creation_input_tokens: 100, + cache_read_input_tokens: 50 + }); - // Fire token usage event for active branch - tokenUsageEmitter.fire(createTokenUsage({ + await service.checkAndHandleSummarization( sessionId, - requestId, - inputTokens: 5000, - outputTokens: 200 - })); + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); - // Verify branchTokens was updated with totalTokens (inputTokens + outputTokens) - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[activeBranchId]).to.equal(5200); // 5000 input + 200 output + // Total input = input_tokens + cache_creation + cache_read = 1000 + 100 + 50 = 1150 + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 1150, 200)).to.be.true; }); - it('should NOT trigger tracker reset for non-active branch but should store tokens', () => { - const sessionId = 'session-3'; - const activeBranchId = 'branch-B'; - const nonActiveBranchId = 'branch-A'; - const requestId = `request-for-${nonActiveBranchId}`; - // Active branch is B, but we fire event for branch A - const session = createMockSession(sessionId, activeBranchId, [ - { id: nonActiveBranchId }, - { id: activeBranchId } // Last element is active - ]); + it('should consume pending split and handle mid-turn split', async () => { + const sessionId = 'session-1'; + const requestId = 'request-1'; + const pendingToolCalls: ToolCall[] = [ + { id: 'tool-1', function: { name: 'test_tool', arguments: '{}' }, finished: false } + ]; + const toolResults = new Map([['tool-1', 'result']]); + // Create mock session + const session = createMockSession(sessionId, 'branch-1'); sessionRegistry.set(sessionId, session); - const callCountBefore = tokenTracker.resetSessionTokens.callCount; + // Mark pending split + service.markPendingSplit(sessionId, requestId, pendingToolCalls, toolResults); - // Fire token usage event for non-active branch - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId, - inputTokens: 3000, - outputTokens: 150 - })); + const mockAgent = { invoke: sinon.stub().resolves() }; + const mockResponse = { + isComplete: false, + complete: sinon.stub(), + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub() + } + }; + const mockRequest = { + id: requestId, + request: { kind: 'user' }, + response: mockResponse + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); + + // Mock model methods needed for handleMidTurnSplit + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { addRequest: sinon.SinonStub }).addRequest = sinon.stub().returns({ + id: 'new-request', + response: { + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub(), + asDisplayString: sinon.stub().returns('Summary text') + } + } + }); + (modelStub as unknown as { getRequests: sinon.SinonStub }).getRequests = sinon.stub().returns([]); - // Verify tokenTracker.resetSessionTokens was NOT called additionally - expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); - // But branchTokens should be updated with totalTokens (inputTokens + outputTokens) - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[nonActiveBranchId]).to.equal(3150); // 3000 input + 150 output + // Should return true because pending split was handled + expect(result).to.be.true; + // Response should be completed + expect(mockResponse.complete.called).to.be.true; }); + }); + describe('branch change handling', () => { it('should restore stored tokens when branch changes', () => { const sessionId = 'session-4'; const branchA = 'branch-A'; @@ -394,72 +486,6 @@ describe('ChatSessionSummarizationServiceImpl', () => { expect(tokenTracker.resetSessionTokens.calledWith(sessionId, undefined)).to.be.true; }); - it('should reset session tokens for active branch with valid input tokens', async () => { - const sessionId = 'session-6'; - const activeBranchId = 'branch-active'; - const nonActiveBranchId = 'branch-other'; - - // Create session with two branches, active is the last one - const session = createMockSession(sessionId, activeBranchId, [ - { id: nonActiveBranchId }, - { id: activeBranchId } - ]); - - sessionRegistry.set(sessionId, session); - - const resetCallCountBefore = tokenTracker.resetSessionTokens.callCount; - - // Fire token usage event exceeding threshold for NON-active branch - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId: `request-for-${nonActiveBranchId}`, - inputTokens: CHAT_TOKEN_THRESHOLD + 10000, - outputTokens: 100 - })); - - // resetSessionTokens should NOT be called for non-active branch - expect(tokenTracker.resetSessionTokens.callCount).to.equal(resetCallCountBefore); - - // Now fire for active branch - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId: `request-for-${activeBranchId}`, - inputTokens: CHAT_TOKEN_THRESHOLD + 10000, - outputTokens: 100 - })); - - // updateSessionTokens SHOULD be called for active branch with separate input/output values - expect(tokenTracker.updateSessionTokens.calledWith(sessionId, CHAT_TOKEN_THRESHOLD + 10000, 100)).to.be.true; - }); - - it('should remove all branch entries when session is deleted', () => { - const sessionId = 'session-to-delete'; - - // Pre-populate branch tokens via tracker and triggeredBranches - tokenTracker.setBranchTokens(sessionId, 'branch-A', 1000); - tokenTracker.setBranchTokens(sessionId, 'branch-B', 2000); - tokenTracker.setBranchTokens('other-session', 'branch-X', 5000); - - const triggeredBranchesSet = (service as unknown as { triggeredBranches: Set }).triggeredBranches; - // Note: cleanupSession uses prefix `${sessionId}: ` (with trailing space) for matching - triggeredBranchesSet.add(`${sessionId}: branch-A`); - triggeredBranchesSet.add(`${sessionId}: branch-B`); - triggeredBranchesSet.add('other-session: branch-X'); - - // Fire session deleted event - sessionEventEmitter.fire({ type: 'deleted', sessionId }); - - // Verify clearSessionBranchTokens was called - expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; - - // Verify triggeredBranches entries for deleted session are removed - expect(triggeredBranchesSet.has(`${sessionId}: branch-A`)).to.be.false; - expect(triggeredBranchesSet.has(`${sessionId}: branch-B`)).to.be.false; - - // Verify other session's triggeredBranches entries are preserved - expect(triggeredBranchesSet.has('other-session: branch-X')).to.be.true; - }); - it('should populate branchTokens on persistence restore', () => { const sessionId = 'restored-session'; const activeBranchId = 'branch-restored'; @@ -497,145 +523,38 @@ describe('ChatSessionSummarizationServiceImpl', () => { const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); expect(branchTokens).to.deep.equal(branchTokensData); }); - - it('should skip summary requests in token handler', () => { - const sessionId = 'session-7'; - const branchId = 'branch-A'; - const summaryRequestId = 'summary-request-for-branch-A'; - - // Create session where getRequest returns summary kind for specific request - const modelChangeEmitter = new Emitter(); - const session = { - id: sessionId, - isActive: true, - model: { - getBranch: sinon.stub().callsFake((requestId: string) => { - if (requestId === summaryRequestId) { - return { id: branchId }; - } - return undefined; - }), - getBranches: sinon.stub().returns([{ id: branchId }]), - getRequest: sinon.stub().callsFake((requestId: string) => { - if (requestId === summaryRequestId) { - return { request: { kind: 'summary' } }; - } - return { request: { kind: 'user' } }; - }), - onDidChange: modelChangeEmitter.event - } - } as unknown as ChatSession; - - sessionRegistry.set(sessionId, session); - - const callCountBefore = tokenTracker.resetSessionTokens.callCount; - - // Fire token usage event for summary request - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId: summaryRequestId, - inputTokens: 5000, - outputTokens: 200 - })); - - // Verify branchTokens was NOT updated - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.be.undefined; - - // Verify tokenTracker.resetSessionTokens was NOT called additionally - expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); - }); - - it('should include readCachedInputTokens in total context size', () => { - const sessionId = 'session-8'; - const branchId = 'branch-A'; - const requestId = `request-for-${branchId}`; - const session = createMockSession(sessionId, branchId); - - sessionRegistry.set(sessionId, session); - - // Fire token usage event with cached tokens - // For providers like Anthropic with caching: - // - inputTokens: raw/non-cached input tokens - // - readCachedInputTokens: cached tokens read from cache - // Total context size = inputTokens + readCachedInputTokens - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId, - inputTokens: 1000, // Raw input tokens (non-cached) - cachedInputTokens: 500, // Tokens written to cache (informational only) - readCachedInputTokens: 200, // Cached tokens read - adds to context size - outputTokens: 100 - })); - - // Verify branchTokens includes readCachedInputTokens for total context size - // totalInputTokens = 1000 (input) + 200 (cached read) = 1200 - // totalTokens = 1200 + 100 (output) = 1300 - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens[branchId]).to.equal(1300); // 1000 + 200 + 100 - }); - - it('should not update branchTokens when session is not found', () => { - const sessionId = 'non-existent-session'; - - // Don't add to sessionRegistry - session not found - - // Fire token usage event - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId: 'some-request', - inputTokens: 1000, - outputTokens: 100 - })); - - // Verify branchTokens was NOT updated - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens).to.deep.equal({}); - }); - - it('should not update branchTokens when branch is not found for request', () => { - const sessionId = 'session-9'; - const branchId = 'branch-A'; - const unknownRequestId = 'unknown-request'; - - // Create session where getBranch returns undefined for unknown request - const session = createMockSession(sessionId, branchId); - ((session.model as unknown as { getBranch: sinon.SinonStub }).getBranch).withArgs(unknownRequestId).returns(undefined); - - sessionRegistry.set(sessionId, session); - - const callCountBefore = tokenTracker.resetSessionTokens.callCount; - - // Fire token usage event for unknown request - tokenUsageEmitter.fire(createTokenUsage({ - sessionId, - requestId: unknownRequestId, - inputTokens: 1000, - outputTokens: 100 - })); - - // Verify branchTokens was NOT updated - const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); - expect(branchTokens).to.deep.equal({}); - - // Verify tokenTracker.resetSessionTokens was NOT called additionally - expect(tokenTracker.resetSessionTokens.callCount).to.equal(callCountBefore); - }); - }); describe('cleanupSession', () => { - it('should clean up pendingSplits when session is deleted', () => { + it('should clean up all session data when session is deleted', () => { const sessionId = 'session-to-cleanup'; + // Pre-populate branch tokens via tracker + tokenTracker.setBranchTokens(sessionId, 'branch-A', 1000); + tokenTracker.setBranchTokens(sessionId, 'branch-B', 2000); + tokenTracker.setBranchTokens('other-session', 'branch-X', 5000); + // Add pending split service.markPendingSplit(sessionId, 'request-1', [], new Map()); + // Add to triggeredBranches + const triggeredBranchesSet = (service as unknown as { triggeredBranches: Set }).triggeredBranches; + triggeredBranchesSet.add(`${sessionId}: branch-A`); + triggeredBranchesSet.add(`${sessionId}: branch-B`); + triggeredBranchesSet.add('other-session: branch-X'); + // Fire session deleted event sessionEventEmitter.fire({ type: 'deleted', sessionId }); - // Verify tokenTracker cleanup was called + // Verify clearSessionBranchTokens was called expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; + + // Verify triggeredBranches entries for deleted session are removed + expect(triggeredBranchesSet.has(`${sessionId}: branch-A`)).to.be.false; + expect(triggeredBranchesSet.has(`${sessionId}: branch-B`)).to.be.false; + + // Verify other session's triggeredBranches entries are preserved + expect(triggeredBranchesSet.has('other-session: branch-X')).to.be.true; }); }); }); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts index 3b5b2bcd64a99..fb9167809100c 100644 --- a/packages/ai-chat/src/browser/chat-session-summarization-service.ts +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -17,7 +17,7 @@ import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; import { ILogger, nls } from '@theia/core'; import { FrontendApplicationContribution } from '@theia/core/lib/browser'; -import { ToolCall, ToolCallResult, TokenUsage, TokenUsageServiceClient } from '@theia/ai-core'; +import { ToolCall, ToolCallResult, UsageResponsePart } from '@theia/ai-core'; import { ChatAgent, ChatService, @@ -35,8 +35,7 @@ import { import { isSessionCreatedEvent, isSessionDeletedEvent } from '../common/chat-service'; import { CHAT_TOKEN_THRESHOLD, - ChatSessionTokenTracker, - SessionTokenThresholdEvent + ChatSessionTokenTracker } from './chat-session-token-tracker'; export const ChatSessionSummarizationService = Symbol('ChatSessionSummarizationService'); @@ -73,12 +72,15 @@ export interface ChatSessionSummarizationService { * @param sessionId The session ID * @param agent The chat agent to invoke for summary/continuation * @param request The current request being processed + * @param usage Usage data from the response stream for synchronous token tracking. + * May be undefined if the stream ended early (e.g., due to budget-exceeded split). * @returns true if summarization was triggered (caller should skip onResponseComplete), false otherwise */ checkAndHandleSummarization( sessionId: string, agent: ChatAgent, - request: MutableChatRequestModel + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined ): Promise; } @@ -93,9 +95,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza @inject(ILogger) protected readonly logger: ILogger; - @inject(TokenUsageServiceClient) - protected readonly tokenUsageClient: TokenUsageServiceClient; - /** * Set of sessionIds currently being summarized to prevent concurrent summarization. */ @@ -128,9 +127,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza @postConstruct() protected init(): void { - // Listen to token usage events and attribute to correct branch - this.tokenUsageClient.onTokenUsageUpdated(usage => this.handleTokenUsage(usage)); - // Listen for new sessions and set up branch change listeners this.chatService.onSessionEvent(event => { if (isSessionCreatedEvent(event)) { @@ -174,47 +170,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return session.model.getBranches().at(-1)?.id; } - /** - * Handle token usage events and attribute to correct branch. - */ - protected handleTokenUsage(usage: TokenUsage): void { - if (!usage.sessionId) { - return; - } - - const session = this.chatService.getSession(usage.sessionId); - if (!session) { - return; - } - - const model = session.model as MutableChatModel; - const branch = model.getBranch(usage.requestId); - if (!branch) { - this.logger.debug('Token event for unknown request', { sessionId: usage.sessionId, requestId: usage.requestId }); - return; - } - - // Skip summary requests - the per-summarization listener handles these - if (model.getRequest(usage.requestId)?.request.kind === 'summary') { - return; - } - - // Total input = raw input + cached tokens read (for providers like Anthropic with caching) - const totalInputTokens = usage.inputTokens + (usage.readCachedInputTokens ?? 0); - const totalTokens = totalInputTokens + (usage.outputTokens ?? 0); - - // Update branch tokens (for branch switching) - if (totalTokens > 0) { - this.tokenTracker.setBranchTokens(usage.sessionId, branch.id, totalTokens); - } - - const activeBranchId = this.getActiveBranchId(session); - - if (branch.id === activeBranchId) { - this.tokenTracker.updateSessionTokens(usage.sessionId, totalInputTokens, usage.outputTokens); - } - } - /** * Set up a listener for branch changes in a chat session. * When a branch change occurs (e.g., user edits an older message), reset token tracking. @@ -222,7 +177,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza protected setupBranchChangeListener(session: ChatSession): void { session.model.onDidChange(event => { if (event.kind === 'changeHierarchyBranch') { - this.logger.info(`Branch changed in session ${session.id}, switching to branch ${event.branch.id}`); const storedTokens = this.tokenTracker.getBranchTokens(session.id, event.branch.id); this.tokenTracker.resetSessionTokens(session.id, storedTokens); } @@ -235,16 +189,26 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza pendingToolCalls: ToolCall[], toolResults: Map ): void { - this.logger.info(`Marking pending split for session ${sessionId}, request ${requestId}, ${pendingToolCalls.length} tool calls`); this.pendingSplits.set(sessionId, { requestId, pendingToolCalls, toolResults }); } + /** + * Update token tracking during streaming. + * Called when usage data is received in the stream, before the response completes. + * This enables real-time token count updates in the UI. + */ + updateTokens(sessionId: string, usage: UsageResponsePart): void { + const totalInputTokens = usage.input_tokens + (usage.cache_creation_input_tokens ?? 0) + (usage.cache_read_input_tokens ?? 0); + this.tokenTracker.updateSessionTokens(sessionId, totalInputTokens, usage.output_tokens); + } + async checkAndHandleSummarization( sessionId: string, agent: ChatAgent, - request: MutableChatRequestModel + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined ): Promise { - // Check for pending mid-turn split first + // Check for pending mid-turn split first (may exist even without usage data) const pendingSplit = this.pendingSplits.get(sessionId); if (pendingSplit) { // Consume immediately to prevent re-entry @@ -253,31 +217,41 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return true; } - // Between-turn check: skip if summary or continuation request - if (request.request.kind === 'summary' || request.request.kind === 'continuation') { + // If no usage data, nothing more to do + if (!usage) { return false; } - // Check if threshold exceeded for between-turn summarization - const tokens = this.tokenTracker.getSessionInputTokens(sessionId); - if (tokens === undefined || tokens < CHAT_TOKEN_THRESHOLD) { + // Always skip summary requests before any token work + if (request.request.kind === 'summary') { return false; } - // Between-turn summarization - trigger via existing performSummarization - const session = this.chatService.getSession(sessionId); - if (!session) { + // Calculate tokens for all other requests (user and continuation) + const totalInputTokens = usage.input_tokens + (usage.cache_creation_input_tokens ?? 0) + (usage.cache_read_input_tokens ?? 0); + this.tokenTracker.updateSessionTokens(sessionId, totalInputTokens, usage.output_tokens); + + // Skip continuation requests only if below threshold + if (request.request.kind === 'continuation' && totalInputTokens < CHAT_TOKEN_THRESHOLD) { return false; } - // Complete current response first if not already - if (!request.response.isComplete) { - request.response.complete(); + // Check threshold with fresh data + if (totalInputTokens >= CHAT_TOKEN_THRESHOLD) { + const session = this.chatService.getSession(sessionId); + if (session) { + // Complete current response first if not already + if (!request.response.isComplete) { + request.response.complete(); + } + await this.performSummarization(sessionId, session.model as MutableChatModel); + return true; + } else { + this.logger.warn(`Session ${sessionId} not found for between-turn summarization`); + } } - // Use existing performSummarization for between-turn (it marks stale after summary) - await this.performSummarization(sessionId, session.model as MutableChatModel); - return true; + return false; } protected async handleMidTurnSplit( @@ -299,7 +273,7 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza // Step 2: Complete current response request.response.complete(); - // Step 3: Create and invoke summary request (NO stale marking yet - summary needs full history) + // Step 3: Create summary request (stale marking deferred so summary sees full history) // eslint-disable-next-line max-len const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. Do not include conversational elements, questions, or offers to continue. Do not start with a heading - output only the summary content.'; @@ -314,6 +288,10 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza // Invoke agent for summary (will populate summaryRequest.response) await agent.invoke(summaryRequest); + // Reset token tracking with summary output as new baseline BEFORE continuation + const summaryOutputTokens = this.tokenTracker.getSessionOutputTokens(sessionId) ?? 0; + this.updateTokenTrackingAfterSummary(sessionId, summaryOutputTokens); + // Get summary text from response const summaryText = summaryRequest.response.response.asDisplayString()?.trim() || ''; @@ -330,8 +308,7 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza } } - // Step 5: Create continuation request with tool call content in response - // Include the summary plus an instruction to continue + // Step 5: Create continuation request with summary and pending tool calls const continuationSuffix = 'The tool call above was executed. Please continue with your task ' + 'based on the result. If you need to make more tool calls to complete the task, please do so. ' + 'Once you have all the information needed, provide your final response.'; @@ -358,7 +335,7 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza continuationRequest.response.response.addContent(toolContent); } - // Step 6: Invoke agent for continuation (token tracking will update from LLM response) + // Step 6: Invoke agent for continuation (token tracking will update normally) await agent.invoke(continuationRequest); } @@ -380,23 +357,6 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza response.response.addContents(filteredContent); } - protected async handleThresholdExceeded(event: SessionTokenThresholdEvent): Promise { - const { sessionId, inputTokens } = event; - - if (this.summarizingSession.has(sessionId)) { - return; - } - - const session = this.chatService.getSession(sessionId); - if (!session) { - this.logger.warn(`Session ${sessionId} not found for summarization`); - return; - } - - this.logger.info(`Token threshold exceeded for session ${sessionId}: ${inputTokens} tokens.Starting summarization...`); - await this.performSummarization(sessionId, session.model as MutableChatModel); - } - /** * Execute a callback with summarization lock for the session. * Ensures lock is released even if callback throws. @@ -429,7 +389,7 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza if (session) { const activeBranchId = this.getActiveBranchId(session); if (activeBranchId) { - const branchKey = `${sessionId}:${activeBranchId} `; + const branchKey = `${sessionId}:${activeBranchId}`; this.tokenTracker.setBranchTokens(sessionId, activeBranchId, outputTokens); this.triggeredBranches.delete(branchKey); } @@ -445,8 +405,7 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza */ protected async performSummarization(sessionId: string, model: MutableChatModel, skipReorder?: boolean): Promise { return this.withSummarizationLock(sessionId, async () => { - // Always use 'end' position - reordering breaks the hierarchy structure - // because the summary is added as continuation of the trigger request + // Always use 'end' position - other positions break hierarchy structure const position = 'end'; // eslint-disable-next-line max-len const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. ' + @@ -464,41 +423,29 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza } const request = await invocation.requestCompleted; + const response = await invocation.responseCompleted; - // Set up token listener to capture output tokens - let capturedOutputTokens: number | undefined; - const tokenListener = this.tokenUsageClient.onTokenUsageUpdated(usage => { - if (usage.sessionId === sessionId && usage.requestId === request.id) { - capturedOutputTokens = usage.outputTokens; - } - }); - - try { - const response = await invocation.responseCompleted; - - // Validate response - const summaryResponseText = response.response.asDisplayString()?.trim(); - if (response.isError || !summaryResponseText) { - return undefined; - } - - // Replace agent's markdown content with SummaryChatResponseContent for proper rendering - const mutableRequest = request as MutableChatRequestModel; - mutableRequest.response.response.clearContent(); - mutableRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryResponseText)); - - // Store captured output tokens on request for later retrieval - if (capturedOutputTokens !== undefined) { - (request as MutableChatRequestModel).addData('capturedOutputTokens', capturedOutputTokens); - } - - return { - requestId: request.id, - summaryText: summaryResponseText - }; - } finally { - tokenListener.dispose(); + // Validate response + const summaryResponseText = response.response.asDisplayString()?.trim(); + if (response.isError) { + this.logger.error(`Summary response has error: ${response.errorObject?.message}`); + return undefined; + } + if (!summaryResponseText) { + this.logger.error(`Summary response text is empty. Content count: ${response.response.content.length}, ` + + `content kinds: ${response.response.content.map(c => c.kind).join(', ')}`); + return undefined; } + + // Replace agent's markdown content with SummaryChatResponseContent for proper rendering + const mutableRequest = request as MutableChatRequestModel; + mutableRequest.response.response.clearContent(); + mutableRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryResponseText)); + + return { + requestId: request.id, + summaryText: summaryResponseText + }; }, position ); @@ -509,14 +456,10 @@ export class ChatSessionSummarizationServiceImpl implements ChatSessionSummariza return undefined; } - this.logger.info(`Added summary node to session ${sessionId} `); - - // Find the summary request to get captured output tokens - const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); - const outputTokens = summaryRequest?.getDataByKey('capturedOutputTokens') ?? 0; + // Get output tokens from tracker (handleTokenUsage now tracks summary requests) + const outputTokens = this.tokenTracker.getSessionOutputTokens(sessionId) ?? 0; this.updateTokenTrackingAfterSummary(sessionId, outputTokens); - this.logger.info(`Reset token count for session ${sessionId} to ${outputTokens} after summarization`); return summaryText; } catch (error) { diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 24dd673b8d6b7..efc71dfe5d379 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -42,6 +42,7 @@ import { TextMessage, ToolCall, ToolRequest, + UsageResponsePart, } from '@theia/ai-core'; import { Agent, @@ -108,10 +109,25 @@ export const ChatSessionSummarizationServiceSymbol = Symbol('ChatSessionSummariz * Used by AbstractChatAgent to optionally trigger summarization. */ export interface ChatSummarizationService { + /** + * Update token tracking during streaming. + * Called when usage data is received in the stream, before the response completes. + */ + updateTokens( + sessionId: string, + usage: UsageResponsePart + ): void; + + /** + * Check and handle summarization after response completes. + * Called after the stream ends with the final accumulated usage. + * Usage may be undefined if the stream ended early (e.g., due to budget-exceeded split). + */ checkAndHandleSummarization( sessionId: string, agent: ChatAgent, - request: MutableChatRequestModel + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined ): Promise; } @@ -252,8 +268,8 @@ export abstract class AbstractChatAgent implements ChatAgent { ]; const languageModelResponse = await this.sendLlmRequest(request, messages, tools, languageModel); - await this.addContentsToResponse(languageModelResponse, request); - const summarizationHandled = await this.checkSummarization(request); + const usage = await this.addContentsToResponse(languageModelResponse, request); + const summarizationHandled = await this.checkSummarization(request, usage); if (!summarizationHandled) { await this.onResponseComplete(request); } @@ -435,13 +451,18 @@ export abstract class AbstractChatAgent implements ChatAgent { * * Uses the injected ChatSummarizationService if available (browser context). * Returns false in non-browser contexts where the service is not injected. + * Returns false if no usage data is available (LLM doesn't report usage). + * + * @param request The chat request model + * @param usage Usage data from the response stream for synchronous token tracking */ - protected async checkSummarization(request: MutableChatRequestModel): Promise { + protected async checkSummarization(request: MutableChatRequestModel, usage: UsageResponsePart | undefined): Promise { if (this.summarizationService) { return this.summarizationService.checkAndHandleSummarization( request.session.id, this, - request + request, + usage ); } return false; @@ -457,17 +478,18 @@ export abstract class AbstractChatAgent implements ChatAgent { return request.response.complete(); } - protected abstract addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise; + protected abstract addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise; } @injectable() export abstract class AbstractTextToModelParsingChatAgent extends AbstractChatAgent { - protected async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { const responseAsText = await getTextOfResponse(languageModelResponse); const parsedCommand = await this.parseTextResponse(responseAsText); const content = this.createResponseContent(parsedCommand, request); request.response.response.addContent(content); + return undefined; // Text responses don't have usage data in the same way } protected abstract parseTextResponse(text: string): Promise; @@ -497,15 +519,14 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { @inject(ToolCallChatResponseContentFactory) protected toolCallResponseContentFactory: ToolCallChatResponseContentFactory; - protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { if (isLanguageModelTextResponse(languageModelResponse)) { const contents = this.parseContents(languageModelResponse.text, request); request.response.response.addContents(contents); - return; + return undefined; } if (isLanguageModelStreamResponse(languageModelResponse)) { - await this.addStreamResponse(languageModelResponse, request); - return; + return this.addStreamResponse(languageModelResponse, request); } this.logger.error( 'Received unknown response in agent. Return response as text' @@ -515,17 +536,42 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { JSON.stringify(languageModelResponse) ) ); + return undefined; } - protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: MutableChatRequestModel): Promise { + protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: MutableChatRequestModel): Promise { let completeTextBuffer = ''; let startIndex = request.response.response.content.length; + // Accumulate usage data across multiple usage parts (some providers yield input/output separately) + let accumulatedUsage: UsageResponsePart | undefined; for await (const token of languageModelResponse.stream) { // Skip unknown tokens. For example OpenAI sends empty tokens around tool calls if (!isLanguageModelStreamResponsePart(token)) { console.debug(`Unknown token: '${JSON.stringify(token)}'. Skipping`); continue; } + // Accumulate usage data (some providers yield input/output separately) + if (isUsageResponsePart(token)) { + if (!accumulatedUsage) { + accumulatedUsage = { ...token }; + } else { + // Accumulate non-zero values (providers may yield input and output in separate parts) + if (token.input_tokens > 0) { + accumulatedUsage.input_tokens = token.input_tokens; + } + if (token.output_tokens > 0) { + accumulatedUsage.output_tokens = token.output_tokens; + } + if (token.cache_creation_input_tokens !== undefined) { + accumulatedUsage.cache_creation_input_tokens = token.cache_creation_input_tokens; + } + if (token.cache_read_input_tokens !== undefined) { + accumulatedUsage.cache_read_input_tokens = token.cache_read_input_tokens; + } + } + // Update token tracking in real-time during streaming + this.summarizationService?.updateTokens(request.session.id, accumulatedUsage); + } const newContent = this.parse(token, request); if (!isTextResponsePart(token)) { // For non-text tokens (like tool calls), add them directly @@ -538,8 +584,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { startIndex = request.response.response.content.length; completeTextBuffer = ''; } else { - // parse the entire text so far (since beginning of the stream or last non-text token) - // and replace the entire content with the currently parsed content parts + // Parse accumulated text and replace with parsed content parts completeTextBuffer += token.content; const parsedContents = this.parseContents(completeTextBuffer, request); @@ -552,6 +597,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { request.response.response.addContents(parsedContents); } } + return accumulatedUsage; } protected parse(token: LanguageModelStreamResponsePart, request: MutableChatRequestModel): ChatResponseContent | ChatResponseContent[] { diff --git a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts index d0fb28bedb9ff..fbc9ff24856fb 100644 --- a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts +++ b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts @@ -114,7 +114,7 @@ describe('MutableChatModel.insertSummary()', () => { expect(requests[3].request.kind).to.equal('summary'); }); - it('should mark all requests except the last as stale', async () => { + it('should mark ALL original requests as stale (including trigger)', async () => { const model = createModelWithRequests(3); await model.insertSummary( @@ -123,11 +123,13 @@ describe('MutableChatModel.insertSummary()', () => { ); const requests = model.getRequests(); - // Requests 1-2 (indices 0-1) should be stale, request 3 should not be + // All 3 original requests (indices 0-2) should be stale (including trigger request) + // This is because for between-turn summarization, the trigger request's content + // has ALREADY been summarized and should be excluded from future prompts expect(requests[0].isStale).to.be.true; expect(requests[1].isStale).to.be.true; - expect(requests[2].isStale).to.be.false; - // Summary request (index 3) should also not be stale + expect(requests[2].isStale).to.be.true; + // Summary request (index 3) should NOT be stale (it's created inside callback, after stale list is computed) expect(requests[3].isStale).to.be.false; }); @@ -244,11 +246,12 @@ describe('MutableChatModel.insertSummary()', () => { const requests = model.getRequests(); // First request was already stale, should remain stale expect(requests[0].isStale).to.be.true; - // Second and third requests should now be stale + // Second, third, and fourth requests should now be stale (all original requests) expect(requests[1].isStale).to.be.true; expect(requests[2].isStale).to.be.true; - // Fourth (last before summary) should not be stale - expect(requests[3].isStale).to.be.false; + expect(requests[3].isStale).to.be.true; + // Summary request (index 4) should NOT be stale + expect(requests[4].isStale).to.be.false; }); }); }); diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index db510156a6134..543c88554e989 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -965,11 +965,11 @@ export class MutableChatModel implements ChatModel, Disposable { return undefined; } - // The request to preserve (most recent exchange, not summarized) - const requestToPreserve = allRequests[allRequests.length - 1]; + // For 'end' position (between-turn), all existing requests are summarized and should be marked stale + // For other positions, preserve the most recent exchange + const requestToPreserve = position === 'end' ? undefined : allRequests[allRequests.length - 1]; // Identify which requests will be marked stale after successful summarization - // (all non-stale requests except the preserved one) const requestsToMarkStale = allRequests.filter(r => !r.isStale && r !== requestToPreserve); // Call the callback to create the summary request and invoke the agent @@ -1300,7 +1300,7 @@ export class ChatRequestHierarchyImpl[] { - return Array.from(this.iterateBranches()); + return Array.from(this.iterateBranches()).filter(b => b.items.length > 0); } protected *iterateBranches(): Generator> { diff --git a/packages/ai-google/src/node/google-language-model.ts b/packages/ai-google/src/node/google-language-model.ts index bfd69f1e3bfb6..4b5a35c6dc0ca 100644 --- a/packages/ai-google/src/node/google-language-model.ts +++ b/packages/ai-google/src/node/google-language-model.ts @@ -309,17 +309,15 @@ export class GoogleModel implements LanguageModel { yield { content: chunk.text }; } - // Report token usage if available - if (chunk.usageMetadata && that.tokenUsageService && that.id) { + // Yield usage data when available + if (chunk.usageMetadata) { const promptTokens = chunk.usageMetadata.promptTokenCount; const completionTokens = chunk.usageMetadata.candidatesTokenCount; - if (promptTokens && completionTokens) { - that.tokenUsageService.recordTokenUsage(that.id, { - inputTokens: promptTokens, - outputTokens: completionTokens, - requestId: request.requestId, - sessionId: request.sessionId - }).catch(error => console.error('Error recording token usage:', error)); + if (promptTokens !== undefined && completionTokens !== undefined) { + yield { + input_tokens: promptTokens, + output_tokens: completionTokens + }; } } } @@ -445,19 +443,6 @@ export class GoogleModel implements LanguageModel { responseText = model.text ?? ''; } - // Record token usage if available - if (model.usageMetadata && this.tokenUsageService) { - const promptTokens = model.usageMetadata.promptTokenCount; - const completionTokens = model.usageMetadata.candidatesTokenCount; - if (promptTokens && completionTokens) { - await this.tokenUsageService.recordTokenUsage(this.id, { - inputTokens: promptTokens, - outputTokens: completionTokens, - requestId: request.requestId, - sessionId: request.sessionId - }); - } - } return { text: responseText }; } catch (error) { throw new Error(`Failed to get response from Gemini API: ${error instanceof Error ? error.message : 'Unknown error'}`); diff --git a/packages/ai-history/src/browser/ai-history-exchange-card.tsx b/packages/ai-history/src/browser/ai-history-exchange-card.tsx index f97a02e9f3cd1..832ed52338e61 100644 --- a/packages/ai-history/src/browser/ai-history-exchange-card.tsx +++ b/packages/ai-history/src/browser/ai-history-exchange-card.tsx @@ -29,6 +29,22 @@ const getTextFromResponse = (response: LanguageModelExchangeRequestResponse): st for (const chunk of response.parts) { if ('content' in chunk && chunk.content) { result += chunk.content; + } else if ('tool_calls' in chunk && Array.isArray(chunk.tool_calls)) { + // Format tool calls for display - only show finished tool calls to avoid duplicates + for (const toolCall of chunk.tool_calls) { + if (toolCall.finished && toolCall.function?.name) { + result += `[Tool Call: ${toolCall.function.name}`; + if (toolCall.function.arguments) { + // Truncate long arguments for readability + const args = toolCall.function.arguments; + const truncatedArgs = args.length > 100 ? args.substring(0, 100) + '...' : args; + result += `(${truncatedArgs})`; + } + result += ']\n'; + } + } + } else if ('thought' in chunk && chunk.thought) { + result += `[Thinking: ${chunk.thought.substring(0, 100)}...]\n`; } } return result; diff --git a/packages/ai-ide/src/common/orchestrator-chat-agent.ts b/packages/ai-ide/src/common/orchestrator-chat-agent.ts index 93731af3c165e..852f389052ba4 100644 --- a/packages/ai-ide/src/common/orchestrator-chat-agent.ts +++ b/packages/ai-ide/src/common/orchestrator-chat-agent.ts @@ -14,7 +14,16 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { AIVariableContext, getJsonOfText, getTextOfResponse, LanguageModel, LanguageModelMessage, LanguageModelRequirement, LanguageModelResponse } from '@theia/ai-core'; +import { + AIVariableContext, + getJsonOfText, + getTextOfResponse, + LanguageModel, + LanguageModelMessage, + LanguageModelRequirement, + LanguageModelResponse, + UsageResponsePart +} from '@theia/ai-core'; import { inject, injectable } from '@theia/core/shared/inversify'; import { ChatAgentService } from '@theia/ai-chat/lib/common/chat-agent-service'; import { ChatToolRequest } from '@theia/ai-chat/lib/common/chat-tool-request-service'; @@ -126,7 +135,7 @@ export class OrchestratorChatAgent extends AbstractStreamParsingChatAgent { ); } - protected override async addContentsToResponse(response: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected override async addContentsToResponse(response: LanguageModelResponse, request: MutableChatRequestModel): Promise { const responseText = await getTextOfResponse(response); let agentIds: string[] = []; @@ -184,6 +193,9 @@ export class OrchestratorChatAgent extends AbstractStreamParsingChatAgent { // Get the original request if available const originalRequest = '__originalRequest' in request ? request.__originalRequest as MutableChatRequestModel : request; await agent.invoke(originalRequest); + + // Orchestrator delegates to another agent, no usage data to return + return undefined; } } diff --git a/packages/ai-ollama/src/node/ollama-language-model.ts b/packages/ai-ollama/src/node/ollama-language-model.ts index 7ea4d4013ad64..1a48cda94ee91 100644 --- a/packages/ai-ollama/src/node/ollama-language-model.ts +++ b/packages/ai-ollama/src/node/ollama-language-model.ts @@ -30,7 +30,7 @@ import { UserRequest } from '@theia/ai-core'; import { CancellationToken } from '@theia/core'; -import { ChatRequest, Message, Ollama, Options, Tool, ToolCall as OllamaToolCall, ChatResponse } from 'ollama'; +import { ChatRequest, Message, Ollama, Options, Tool, ToolCall as OllamaToolCall } from 'ollama'; export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier'); @@ -158,7 +158,13 @@ export class OllamaModel implements LanguageModel { } if (chunk.done) { - that.recordTokenUsage(chunk, sessionId); + // Yield usage data when stream completes + if (chunk.prompt_eval_count !== undefined && chunk.eval_count !== undefined) { + yield { + input_tokens: chunk.prompt_eval_count, + output_tokens: chunk.eval_count + }; + } if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -267,9 +273,8 @@ export class OllamaModel implements LanguageModel { toolCalls.push(...chunk.message.tool_calls); } - // if the response is done, record the token usage and check the done reason + // if the response is done, check the done reason if (chunk.done) { - this.recordTokenUsage(chunk, sessionId); lastUpdated = chunk.created_at; if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -333,17 +338,6 @@ export class OllamaModel implements LanguageModel { return toolCallsForResponse; } - private recordTokenUsage(response: ChatResponse, sessionId?: string): void { - if (this.tokenUsageService && response.prompt_eval_count && response.eval_count) { - this.tokenUsageService.recordTokenUsage(this.id, { - inputTokens: response.prompt_eval_count, - outputTokens: response.eval_count, - requestId: `ollama_${response.created_at}`, - sessionId - }).catch(error => console.error('Error recording token usage:', error)); - } - } - protected initializeOllama(): Ollama { const host = this.host(); if (!host) { diff --git a/packages/ai-openai/src/node/openai-language-model.ts b/packages/ai-openai/src/node/openai-language-model.ts index a59094ca175dc..802b0d5fcbd69 100644 --- a/packages/ai-openai/src/node/openai-language-model.ts +++ b/packages/ai-openai/src/node/openai-language-model.ts @@ -174,19 +174,6 @@ export class OpenAiModel implements LanguageModel { const message = response.choices[0].message; - // Record token usage if token usage service is available - if (this.tokenUsageService && response.usage) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: response.usage.prompt_tokens, - outputTokens: response.usage.completion_tokens, - requestId: request.requestId, - sessionId: request.sessionId - } - ); - } - return { text: message.content ?? '' }; @@ -210,19 +197,6 @@ export class OpenAiModel implements LanguageModel { console.error('Error in OpenAI chat completion stream:', JSON.stringify(message)); } - // Record token usage if token usage service is available - if (this.tokenUsageService && result.usage) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: result.usage.prompt_tokens, - outputTokens: result.usage.completion_tokens, - requestId: request.requestId, - sessionId: request.sessionId - } - ); - } - return { content: message.content ?? '', parsed: message.parsed diff --git a/packages/ai-openai/src/node/openai-response-api-utils.ts b/packages/ai-openai/src/node/openai-response-api-utils.ts index 7ca081fb92bc7..6a92fc67b3069 100644 --- a/packages/ai-openai/src/node/openai-response-api-utils.ts +++ b/packages/ai-openai/src/node/openai-response-api-utils.ts @@ -104,18 +104,6 @@ export class OpenAiResponseApiUtils { ...settings }); - // Record token usage if available - if (tokenUsageService && response.usage) { - await tokenUsageService.recordTokenUsage( - modelId, - { - inputTokens: response.usage.input_tokens, - outputTokens: response.usage.output_tokens, - requestId: request.requestId - } - ); - } - return { text: response.output_text || '' }; } } @@ -186,16 +174,12 @@ export class OpenAiResponseApiUtils { content: event.delta }; } else if (event.type === 'response.completed') { - if (tokenUsageService && event.response?.usage) { - await tokenUsageService.recordTokenUsage( - modelId, - { - inputTokens: event.response.usage.input_tokens, - outputTokens: event.response.usage.output_tokens, - requestId, - sessionId - } - ); + // Yield usage data when response completes + if (event.response?.usage) { + yield { + input_tokens: event.response.usage.input_tokens, + output_tokens: event.response.usage.output_tokens + }; } } else if (event.type === 'error') { console.error('Response API error:', event.message); @@ -446,19 +430,6 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator { @@ -62,19 +62,15 @@ export class StreamingAsyncIterator implements AsyncIterableIterator { - // Handle token usage reporting - if (chunk.usage && this.tokenUsageService && this.model) { + // Yield token usage as UsageResponsePart when available + if (chunk.usage) { const inputTokens = chunk.usage.prompt_tokens || 0; const outputTokens = chunk.usage.completion_tokens || 0; if (inputTokens > 0 || outputTokens > 0) { - const tokenUsageParams: TokenUsageParams = { - inputTokens, - outputTokens, - requestId, - sessionId: this.sessionId - }; - this.tokenUsageService.recordTokenUsage(this.model, tokenUsageParams) - .catch(error => console.error('Error recording token usage:', error)); + this.handleIncoming({ + input_tokens: inputTokens, + output_tokens: outputTokens + }); } } // OpenAI API defines the type of a tool_call as optional but fails if it is not set diff --git a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts index 94f2e140ba3fa..e598e41cb3bfe 100644 --- a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts +++ b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts @@ -33,11 +33,8 @@ import { CancellationToken, Disposable, ILogger } from '@theia/core'; import { CoreMessage, generateObject, - GenerateObjectResult, generateText, - GenerateTextResult, jsonSchema, - StepResult, streamText, TextStreamPart, tool, @@ -267,8 +264,6 @@ export class VercelAiModel implements LanguageModel { ...settings }); - await this.recordTokenUsage(response, request); - return { text: response.text }; } @@ -323,31 +318,12 @@ export class VercelAiModel implements LanguageModel { ...settings }); - await this.recordTokenUsage(response, request); - return { content: JSON.stringify(response.object), parsed: response.object }; } - private async recordTokenUsage( - result: GenerateObjectResult | GenerateTextResult, - request: UserRequest - ): Promise { - if (this.tokenUsageService && !isNaN(result.usage.completionTokens) && !isNaN(result.usage.promptTokens)) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: result.usage.promptTokens, - outputTokens: result.usage.completionTokens, - requestId: request.requestId, - sessionId: request.sessionId - } - ); - } - } - protected async handleStreamingRequest( model: LanguageModelV1, request: UserRequest, @@ -367,16 +343,6 @@ export class VercelAiModel implements LanguageModel { maxRetries: this.maxRetries, toolCallStreaming: true, abortSignal, - onStepFinish: (stepResult: StepResult) => { - if (!isNaN(stepResult.usage.completionTokens) && !isNaN(stepResult.usage.promptTokens)) { - this.tokenUsageService?.recordTokenUsage(this.id, { - inputTokens: stepResult.usage.promptTokens, - outputTokens: stepResult.usage.completionTokens, - requestId: request.requestId, - sessionId: request.sessionId - }); - } - }, ...settings });