diff --git a/.prompts/project-info.prompttemplate b/.prompts/project-info.prompttemplate index 1a31686d4c108..c6766ff6fb8e6 100644 --- a/.prompts/project-info.prompttemplate +++ b/.prompts/project-info.prompttemplate @@ -59,6 +59,20 @@ Tests are located in the same directory as the components under test. If you want to compile something, run the linter or tests, prefer to execute them for changed packages first, as they will run faster. Only build the full project once you are done for a final validation. +### Building and Running the Demo App + +The main example applications are in `/examples/`: + +| Command (from root) | Purpose | +|---------------------|---------| +| `npm ci` | Install dependencies (required first) | +| `npm run build:browser` | Build all packages + browser app | +| `npm run start:browser` | Start browser example at localhost:3000 | +| `npm run start:electron` | Start Electron desktop app | +| `npm run watch` | Watch mode for development | + +**Requirements:** Node.js ≥18.17.0, <21 + ### Styling Theia permits extensive color theming and makes extensive use of CSS variables. Styles are typically located either in an `index.css` file for an entire package or in a diff --git a/packages/ai-anthropic/src/node/anthropic-language-model.ts b/packages/ai-anthropic/src/node/anthropic-language-model.ts index 006acd1898b82..9d758206d5021 100644 --- a/packages/ai-anthropic/src/node/anthropic-language-model.ts +++ b/packages/ai-anthropic/src/node/anthropic-language-model.ts @@ -23,7 +23,6 @@ import { LanguageModelStreamResponsePart, LanguageModelTextResponse, TokenUsageService, - TokenUsageParams, UserRequest, ImageContent, ToolCallResult, @@ -263,7 +262,6 @@ export class AnthropicModel implements LanguageModel { const asyncIterator = { async *[Symbol.asyncIterator](): AsyncIterator { - const toolCalls: ToolCallback[] = []; let toolCall: ToolCallback | undefined; const currentMessages: Message[] = []; @@ -313,25 +311,40 @@ export class AnthropicModel implements LanguageModel { } else if (event.type === 'message_start') { currentMessages.push(event.message); currentMessage = event.message; + // Yield initial usage data (input tokens known, output tokens = 0) + if (event.message.usage) { + yield { + input_tokens: event.message.usage.input_tokens, + output_tokens: 0, + cache_creation_input_tokens: event.message.usage.cache_creation_input_tokens ?? undefined, + cache_read_input_tokens: event.message.usage.cache_read_input_tokens ?? undefined + }; + } } else if (event.type === 'message_stop') { if (currentMessage) { - yield { input_tokens: currentMessage.usage.input_tokens, output_tokens: currentMessage.usage.output_tokens }; - // Record token usage if token usage service is available - if (that.tokenUsageService && currentMessage.usage) { - const tokenUsageParams: TokenUsageParams = { - inputTokens: currentMessage.usage.input_tokens, - outputTokens: currentMessage.usage.output_tokens, - cachedInputTokens: currentMessage.usage.cache_creation_input_tokens || undefined, - readCachedInputTokens: currentMessage.usage.cache_read_input_tokens || undefined, - requestId: request.requestId - }; - await that.tokenUsageService.recordTokenUsage(that.id, tokenUsageParams); - } + // Yield final output tokens only (input/cached tokens already yielded at message_start) + yield { + input_tokens: 0, + output_tokens: currentMessage.usage.output_tokens + }; } - } } if (toolCalls.length > 0) { + // singleRoundTrip mode: Return tool calls to caller without executing them. + // This allows external tool loop management (e.g., for budget-aware summarization). + // When enabled, we yield the tool_calls and return immediately; the caller + // handles tool execution and decides whether to continue the conversation. + if (request.singleRoundTrip) { + const pendingCalls = toolCalls.map(tc => ({ + finished: true, + id: tc.id, + function: { name: tc.name, arguments: tc.args.length === 0 ? '{}' : tc.args } + })); + yield { tool_calls: pendingCalls }; + return; + } + const toolResult = await Promise.all(toolCalls.map(async tc => { const tool = request.tools?.find(t => t.name === tc.name); const argsObject = tc.args.length === 0 ? '{}' : tc.args; @@ -410,16 +423,6 @@ export class AnthropicModel implements LanguageModel { const response = await anthropic.messages.create(params); const textContent = response.content[0]; - // Record token usage if token usage service is available - if (this.tokenUsageService && response.usage) { - const tokenUsageParams: TokenUsageParams = { - inputTokens: response.usage.input_tokens, - outputTokens: response.usage.output_tokens, - requestId: request.requestId - }; - await this.tokenUsageService.recordTokenUsage(this.id, tokenUsageParams); - } - if (textContent?.type === 'text') { return { text: textContent.text }; } diff --git a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts index e0250fdf579fb..3214f098b5422 100644 --- a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts +++ b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts @@ -40,6 +40,7 @@ import { TextPartRenderer, } from './chat-response-renderer'; import { UnknownPartRenderer } from './chat-response-renderer/unknown-part-renderer'; +import { SummaryPartRenderer } from './chat-response-renderer/summary-part-renderer'; import { GitHubSelectionResolver, TextFragmentSelectionResolver, @@ -139,6 +140,7 @@ export default new ContainerModule((bind, _unbind, _isBound, rebind) => { bind(ChatResponsePartRenderer).to(TextPartRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(DelegationResponseRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(UnknownPartRenderer).inSingletonScope(); + bind(ChatResponsePartRenderer).to(SummaryPartRenderer).inSingletonScope(); [CommandContribution, MenuContribution].forEach(serviceIdentifier => bind(serviceIdentifier).to(ChatViewMenuContribution).inSingletonScope() ); diff --git a/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx b/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx new file mode 100644 index 0000000000000..5021352bd4db4 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-response-renderer/summary-part-renderer.tsx @@ -0,0 +1,67 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { ChatResponsePartRenderer } from '../chat-response-part-renderer'; +import { inject, injectable } from '@theia/core/shared/inversify'; +import { ChatResponseContent, SummaryChatResponseContent } from '@theia/ai-chat/lib/common'; +import { ReactNode } from '@theia/core/shared/react'; +import { nls } from '@theia/core/lib/common/nls'; +import * as React from '@theia/core/shared/react'; +import { OpenerService } from '@theia/core/lib/browser'; +import { useMarkdownRendering } from './markdown-part-renderer'; + +/** + * Renderer for SummaryChatResponseContent. + * Displays the summary in a collapsible section that is collapsed by default. + */ +@injectable() +export class SummaryPartRenderer implements ChatResponsePartRenderer { + + @inject(OpenerService) + protected readonly openerService: OpenerService; + + canHandle(response: ChatResponseContent): number { + if (SummaryChatResponseContent.is(response)) { + return 10; + } + return -1; + } + + render(response: SummaryChatResponseContent): ReactNode { + return ; + } +} + +interface SummaryContentProps { + content: string; + openerService: OpenerService; +} + +const SummaryContent: React.FC = ({ content, openerService }) => { + const contentRef = useMarkdownRendering(content, openerService); + + return ( +
+
+ + + {nls.localize('theia/ai/chat-ui/summary-part-renderer/conversationSummary', 'Conversation Summary')} + +
+
+
+ ); +}; diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts new file mode 100644 index 0000000000000..9d01dc4b709e0 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.spec.ts @@ -0,0 +1,317 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { enableJSDOM } from '@theia/core/lib/browser/test/jsdom'; +let disableJSDOM = enableJSDOM(); +import { FrontendApplicationConfigProvider } from '@theia/core/lib/browser/frontend-application-config-provider'; +FrontendApplicationConfigProvider.set({}); + +import { expect } from 'chai'; +import * as React from '@theia/core/shared/react'; +import * as ReactDOMClient from '@theia/core/shared/react-dom/client'; +import { flushSync } from '@theia/core/shared/react-dom'; +import { Emitter } from '@theia/core'; +import { + ChatSessionTokenTracker, + SessionTokenUpdateEvent +} from '@theia/ai-chat/lib/browser'; +import { ChatTokenUsageIndicator, ChatTokenUsageIndicatorProps } from './chat-token-usage-indicator'; + +disableJSDOM(); + +describe('ChatTokenUsageIndicator', () => { + let container: HTMLDivElement; + let root: ReactDOMClient.Root; + + const createMockTokenTracker = (tokens: number | undefined): ChatSessionTokenTracker => { + const updateEmitter = new Emitter(); + return { + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => tokens, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => tokens, + resetSessionTokens: () => { }, + updateSessionTokens: () => { }, + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } + }; + }; + + before(() => { + disableJSDOM = enableJSDOM(); + }); + + after(() => { + disableJSDOM(); + }); + + beforeEach(() => { + container = document.createElement('div'); + document.body.appendChild(container); + root = ReactDOMClient.createRoot(container); + }); + + afterEach(() => { + flushSync(() => { + root.unmount(); + }); + container.remove(); + }); + + const renderComponent = (props: ChatTokenUsageIndicatorProps): void => { + flushSync(() => { + root.render(React.createElement(ChatTokenUsageIndicator, props)); + }); + }; + + describe('token formatting', () => { + it('should display "-" when no tokens are tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('-'); + }); + + it('should format small token counts as plain numbers', () => { + const mockTracker = createMockTokenTracker(500); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('500'); + }); + + it('should format large token counts with "k" suffix', () => { + const mockTracker = createMockTokenTracker(125000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('125k'); + }); + + it('should format token counts with decimal "k" suffix when needed', () => { + const mockTracker = createMockTokenTracker(1500); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const textContent = container.textContent; + expect(textContent).to.contain('1.5k'); + }); + }); + + describe('color coding', () => { + it('should have green class when usage is below 70%', () => { + // Below 70% of CHAT_TOKEN_THRESHOLD + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-green')).to.be.true; + }); + + it('should have yellow class when usage is between 70% and 90%', () => { + // Between 70% and 90% of CHAT_TOKEN_THRESHOLD (180000 * 0.7 = 126000, 180000 * 0.9 = 162000) + const mockTracker = createMockTokenTracker(150000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-yellow')).to.be.true; + }); + + it('should have red class when usage is at or above 90%', () => { + // At or above 90% of CHAT_TOKEN_THRESHOLD + const mockTracker = createMockTokenTracker(170000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-red')).to.be.true; + }); + + it('should have none class when no tokens are tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + expect(indicator?.classList.contains('token-usage-none')).to.be.true; + }); + }); + + describe('tooltip', () => { + it('should include budget-aware status in tooltip when enabled', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Budget-aware: Enabled'); + }); + + it('should include budget-aware status in tooltip when disabled', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: false + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Budget-aware: Disabled'); + }); + + it('should include threshold and budget values in tooltip', () => { + const mockTracker = createMockTokenTracker(100000); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Threshold:'); + expect(title).to.include('Budget:'); + }); + + it('should show "None" in tooltip when no tokens tracked', () => { + const mockTracker = createMockTokenTracker(undefined); + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + const indicator = container.querySelector('.theia-ChatTokenUsageIndicator'); + const title = indicator?.getAttribute('title'); + expect(title).to.include('Tokens: None'); + }); + }); + + describe('subscription to token updates', () => { + it('should update when token tracker fires update event', () => { + const updateEmitter = new Emitter(); + let currentTokens = 50000; + + const mockTracker: ChatSessionTokenTracker = { + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => currentTokens, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => currentTokens, + resetSessionTokens: () => { }, + updateSessionTokens: () => { }, + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } + }; + + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + // Initial state + let textContent = container.textContent; + expect(textContent).to.contain('50k'); + + // Fire update event within flushSync to ensure synchronous React update + currentTokens = 100000; + flushSync(() => { + updateEmitter.fire({ sessionId: 'test-session', inputTokens: 100000, outputTokens: undefined }); + }); + + textContent = container.textContent; + expect(textContent).to.contain('100k'); + }); + + it('should not update when event is for different session', () => { + const updateEmitter = new Emitter(); + + const mockTracker: ChatSessionTokenTracker = { + onSessionTokensUpdated: updateEmitter.event, + getSessionInputTokens: () => 50000, + getSessionOutputTokens: () => undefined, + getSessionTotalTokens: () => 50000, + resetSessionTokens: () => { }, + updateSessionTokens: () => { }, + setBranchTokens: () => { }, + getBranchTokens: () => undefined, + getBranchTokensForSession: () => ({}), + restoreBranchTokens: () => { }, + clearSessionBranchTokens: () => { } + }; + + renderComponent({ + sessionId: 'test-session', + tokenTracker: mockTracker, + budgetAwareEnabled: true + }); + + // Initial state + let textContent = container.textContent; + expect(textContent).to.contain('50k'); + + // Fire update event for different session within flushSync + flushSync(() => { + updateEmitter.fire({ sessionId: 'other-session', inputTokens: 100000, outputTokens: undefined }); + }); + + textContent = container.textContent; + // Should still show 50k since we didn't update our session + expect(textContent).to.contain('50k'); + }); + }); +}); diff --git a/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx new file mode 100644 index 0000000000000..a8a17e826607b --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-token-usage-indicator.tsx @@ -0,0 +1,120 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import * as React from '@theia/core/shared/react'; +import { + ChatSessionTokenTracker, + CHAT_TOKEN_BUDGET, + CHAT_TOKEN_THRESHOLD +} from '@theia/ai-chat/lib/browser'; + +/** Percentage of threshold at which to show warning color (yellow) */ +const TOKEN_USAGE_WARNING_PERCENT = 70; +/** Percentage of threshold at which to show critical color (red) */ +const TOKEN_USAGE_CRITICAL_PERCENT = 90; + +export interface ChatTokenUsageIndicatorProps { + sessionId: string; + tokenTracker: ChatSessionTokenTracker; + budgetAwareEnabled: boolean; +} + +/** + * Formats a token count to a human-readable string. + * E.g., 125000 -> "125k", 1500 -> "1.5k", 500 -> "500" + */ +const formatTokenCount = (tokens: number | undefined): string => { + if (tokens === undefined) { + return '-'; + } + if (tokens >= 1000) { + const k = tokens / 1000; + // Show one decimal place if needed, otherwise whole number + return k % 1 === 0 ? `${k}k` : `${k.toFixed(1)}k`; + } + return tokens.toString(); +}; + +/** + * Determines the color class based on usage percentage. + * - Green: <70% + * - Yellow: 70-<90% + * - Red: ≥90% + */ +const getUsageColorClass = (tokens: number | undefined, threshold: number): string => { + if (tokens === undefined) { + return 'token-usage-none'; + } + const percentage = (tokens / threshold) * 100; + if (percentage >= TOKEN_USAGE_CRITICAL_PERCENT) { + return 'token-usage-red'; + } + if (percentage >= TOKEN_USAGE_WARNING_PERCENT) { + return 'token-usage-yellow'; + } + return 'token-usage-green'; +}; + +/** + * A React component that displays the current token usage for a chat session. + * Shows current input tokens vs threshold with color coding based on usage percentage. + */ +export const ChatTokenUsageIndicator: React.FC = ({ + sessionId, + tokenTracker, + budgetAwareEnabled +}) => { + const [inputTokens, setInputTokens] = React.useState( + () => tokenTracker.getSessionInputTokens(sessionId) + ); + + React.useEffect(() => { + // Get initial value + setInputTokens(tokenTracker.getSessionInputTokens(sessionId)); + + // Subscribe to token updates + const disposable = tokenTracker.onSessionTokensUpdated(event => { + if (event.sessionId === sessionId) { + setInputTokens(event.inputTokens); + } + }); + + return () => disposable.dispose(); + }, [sessionId, tokenTracker]); + + const thresholdFormatted = formatTokenCount(CHAT_TOKEN_THRESHOLD); + const budgetFormatted = formatTokenCount(CHAT_TOKEN_BUDGET); + const currentFormatted = formatTokenCount(inputTokens); + const colorClass = getUsageColorClass(inputTokens, CHAT_TOKEN_THRESHOLD); + + const tooltipText = [ + `Tokens: ${inputTokens !== undefined ? inputTokens.toLocaleString() : 'None'}`, + `Threshold: ${CHAT_TOKEN_THRESHOLD.toLocaleString()} (${thresholdFormatted})`, + `Budget: ${CHAT_TOKEN_BUDGET.toLocaleString()} (${budgetFormatted})`, + `Budget-aware: ${budgetAwareEnabled ? 'Enabled' : 'Disabled'}` + ].join('\n'); + + return ( +
+ + {currentFormatted} / {budgetFormatted} tokens + +
+ ); +}; diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index c2edadf90a7c9..2b47ba76dd616 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -487,7 +487,10 @@ export class ChatViewTreeWidget extends TreeWidget { chatModel.getBranches().forEach(branch => { const request = branch.get(); nodes.push(this.mapRequestToNode(branch)); - nodes.push(this.mapResponseToNode(request.response)); + // Skip separate response node for summary/continuation requests - response is rendered within request node + if (request.request.kind !== 'summary' && request.request.kind !== 'continuation') { + nodes.push(this.mapResponseToNode(request.response)); + } }); this.model.root.children = nodes; this.model.refresh(); @@ -504,9 +507,14 @@ export class ChatViewTreeWidget extends TreeWidget { if (!(isRequestNode(node) || isResponseNode(node))) { return super.renderNode(node, props); } + + // Check if this is a summary or continuation request - skip header for both + const isSummaryOrContinuation = isRequestNode(node) && + (node.request.request.kind === 'summary' || node.request.request.kind === 'continuation'); + return
this.handleContextMenu(node, e)}> - {this.renderAgent(node)} + {!isSummaryOrContinuation && this.renderAgent(node)} {this.renderDetail(node)}
; @@ -622,6 +630,7 @@ export class ChatViewTreeWidget extends TreeWidget { chatAgentService={this.chatAgentService} variableService={this.variableService} openerService={this.openerService} + renderResponseContent={(content: ChatResponseContent, responseNode?: ResponseNode) => this.renderResponseContent(content, responseNode)} provideChatInputWidget={() => { const editableNode = node; if (isEditableRequestNode(editableNode)) { @@ -652,6 +661,21 @@ export class ChatViewTreeWidget extends TreeWidget { />; } + protected renderResponseContent(content: ChatResponseContent, node?: ResponseNode): React.ReactNode { + const renderer = this.chatResponsePartRenderers.getContributions().reduce<[number, ChatResponsePartRenderer | undefined]>( + (prev, current) => { + const prio = current.canHandle(content); + if (prio > prev[0]) { + return [prio, current]; + } return prev; + }, + [-1, undefined])[1]; + if (!renderer) { + return undefined; + } + return renderer.render(content, node as ResponseNode); + } + protected renderChatResponse(node: ResponseNode): React.ReactNode { return (
@@ -757,7 +781,7 @@ const WidgetContainer: React.FC = ({ widget }) => { const ChatRequestRender = ( { node, hoverService, chatAgentService, variableService, openerService, - provideChatInputWidget + provideChatInputWidget, renderResponseContent }: { node: RequestNode, hoverService: HoverService, @@ -765,9 +789,16 @@ const ChatRequestRender = ( variableService: AIVariableService, openerService: OpenerService, provideChatInputWidget: () => ReactWidget | undefined, + renderResponseContent?: (content: ChatResponseContent, node?: ResponseNode) => React.ReactNode, }) => { - const parts = node.request.message.parts; - if (EditableChatRequestModel.isEditing(node.request)) { + // Capture the request object once to avoid getter issues + const request = node.request; + const parts = request.message.parts; + const isStale = request.isStale === true; + const isSummary = request.request.kind === 'summary'; + const isContinuation = request.request.kind === 'continuation'; + + if (EditableChatRequestModel.isEditing(request)) { const widget = provideChatInputWidget(); if (widget) { return
@@ -805,43 +836,57 @@ const ChatRequestRender = ( }; return ( -
-

- {parts.map((part, index) => { - if (part instanceof ParsedChatRequestAgentPart || part instanceof ParsedChatRequestVariablePart) { - let description = undefined; - let className = ''; - if (part instanceof ParsedChatRequestAgentPart) { - description = chatAgentService.getAgent(part.agentId)?.description; - className = 'theia-RequestNode-AgentLabel'; - } else if (part instanceof ParsedChatRequestVariablePart) { - description = variableService.getVariable(part.variableName)?.description; - className = 'theia-RequestNode-VariableLabel'; +

+ {(isSummary || isContinuation) && renderResponseContent ? ( +
+ {request.response.response.content.map((c, i) => { + const syntheticResponseNode: ResponseNode = { + id: request.response.id, + parent: node.parent, + response: request.response, + sessionId: node.sessionId + }; + return
{renderResponseContent(c, syntheticResponseNode)}
; + })} +
+ ) : ( +

+ {parts.map((part, index) => { + if (part instanceof ParsedChatRequestAgentPart || part instanceof ParsedChatRequestVariablePart) { + let description = undefined; + let className = ''; + if (part instanceof ParsedChatRequestAgentPart) { + description = chatAgentService.getAgent(part.agentId)?.description; + className = 'theia-RequestNode-AgentLabel'; + } else if (part instanceof ParsedChatRequestVariablePart) { + description = variableService.getVariable(part.variableName)?.description; + className = 'theia-RequestNode-VariableLabel'; + } + return ( + + ); + } else { + const ref = useMarkdownRendering( + part.text + .replace(/^[\r\n]+|[\r\n]+$/g, '') // remove excessive new lines + .replace(/(^ )/g, ' '), // enforce keeping space before + openerService, + true + ); + return ( + + ); } - return ( - - ); - } else { - const ref = useMarkdownRendering( - part.text - .replace(/^[\r\n]+|[\r\n]+$/g, '') // remove excessive new lines - .replace(/(^ )/g, ' '), // enforce keeping space before - openerService, - true - ); - return ( - - ); - } - })} -

- {renderFooter()} + })} +

+ )} + {!isSummary && !isContinuation && renderFooter()}
); }; diff --git a/packages/ai-chat-ui/src/browser/chat-view-widget.tsx b/packages/ai-chat-ui/src/browser/chat-view-widget.tsx index b16555cf52f15..b206fff94b1f5 100644 --- a/packages/ai-chat-ui/src/browser/chat-view-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-view-widget.tsx @@ -15,6 +15,8 @@ // ***************************************************************************** import { CommandService, deepClone, Emitter, Event, MessageService, PreferenceService, URI } from '@theia/core'; import { ChatRequest, ChatRequestModel, ChatService, ChatSession, isActiveSessionChangedEvent, MutableChatModel } from '@theia/ai-chat'; +import { ChatSessionTokenTracker } from '@theia/ai-chat/lib/browser'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '@theia/ai-chat/lib/common/ai-chat-preferences'; import { BaseWidget, codicon, ExtractableWidget, Message, PanelLayout, StatefulWidget } from '@theia/core/lib/browser'; import { nls } from '@theia/core/lib/common/nls'; import { inject, injectable, optional, postConstruct } from '@theia/core/shared/inversify'; @@ -25,6 +27,9 @@ import { AIVariableResolutionRequest } from '@theia/ai-core'; import { ProgressBarFactory } from '@theia/core/lib/browser/progress-bar-factory'; import { FrontendVariableService } from '@theia/ai-core/lib/browser'; import { FrontendLanguageModelRegistry } from '@theia/ai-core/lib/common'; +import { ChatTokenUsageIndicator } from './chat-token-usage-indicator'; +import * as React from '@theia/core/shared/react'; +import { Root, createRoot } from '@theia/core/shared/react-dom/client'; export namespace ChatViewWidget { export interface State { @@ -63,11 +68,17 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta @inject(FrontendLanguageModelRegistry) protected readonly languageModelRegistry: FrontendLanguageModelRegistry; + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + @inject(ChatWelcomeMessageProvider) @optional() protected readonly welcomeProvider?: ChatWelcomeMessageProvider; protected chatSession: ChatSession; + protected tokenIndicatorContainer: HTMLDivElement; + protected tokenIndicatorRoot: Root; + protected _state: ChatViewWidget.State = { locked: false, temporaryLocked: false }; protected readonly onStateChangedEmitter = new Emitter(); @@ -107,8 +118,17 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta layout.addWidget(this.treeWidget); this.inputWidget.node.classList.add('chat-input-widget'); layout.addWidget(this.inputWidget); + + // Add token indicator container after inputWidget is added to the layout + // so insertAdjacentElement can properly place it in the DOM + this.tokenIndicatorContainer = document.createElement('div'); + this.tokenIndicatorContainer.classList.add('chat-token-usage-container'); + this.inputWidget.node.insertAdjacentElement('beforebegin', this.tokenIndicatorContainer); this.chatSession = this.chatService.createSession(); + this.tokenIndicatorRoot = createRoot(this.tokenIndicatorContainer); + this.renderTokenIndicator(); + this.inputWidget.onQuery = this.onQuery.bind(this); this.inputWidget.onUnpin = this.onUnpin.bind(this); this.inputWidget.onCancel = this.onCancel.bind(this); @@ -177,6 +197,7 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta this.treeWidget.trackChatModel(this.chatSession.model); this.inputWidget.chatModel = this.chatSession.model; this.inputWidget.pinnedAgent = this.chatSession.pinnedAgent; + this.renderTokenIndicator(); } else { console.warn(`Session with ${event.sessionId} not found.`); } @@ -299,4 +320,20 @@ export class ChatViewWidget extends BaseWidget implements ExtractableWidget, Sta getSettings(): { [key: string]: unknown } | undefined { return this.chatSession.model.settings; } + + protected renderTokenIndicator(): void { + const budgetAwareEnabled = this.preferenceService.get(BUDGET_AWARE_TOOL_LOOP_PREF, false); + this.tokenIndicatorRoot.render( + React.createElement(ChatTokenUsageIndicator, { + sessionId: this.chatSession.id, + tokenTracker: this.tokenTracker, + budgetAwareEnabled + }) + ); + } + + override dispose(): void { + this.tokenIndicatorRoot.unmount(); + super.dispose(); + } } diff --git a/packages/ai-chat-ui/src/browser/style/index.css b/packages/ai-chat-ui/src/browser/style/index.css index 3c3946c04c0df..37a365b4dfa9c 100644 --- a/packages/ai-chat-ui/src/browser/style/index.css +++ b/packages/ai-chat-ui/src/browser/style/index.css @@ -7,8 +7,8 @@ flex: 1; } -.chat-input-widget > .ps__rail-x, -.chat-input-widget > .ps__rail-y { +.chat-input-widget>.ps__rail-x, +.chat-input-widget>.ps__rail-y { display: none !important; } @@ -23,7 +23,7 @@ overflow-wrap: break-word; } -div:last-child > .theia-ChatNode { +div:last-child>.theia-ChatNode { border: none; } @@ -59,6 +59,7 @@ div:last-child > .theia-ChatNode { } @keyframes dots { + 0%, 20% { content: ""; @@ -137,7 +138,7 @@ div:last-child > .theia-ChatNode { padding-inline-start: 1rem; } -.theia-ChatNode li > p { +.theia-ChatNode li>p { margin-top: 0; margin-bottom: 0; } @@ -151,7 +152,7 @@ div:last-child > .theia-ChatNode { font-size: var(--theia-code-font-size); } -.theia-RequestNode > p div { +.theia-RequestNode>p div { display: inline; } @@ -450,8 +451,7 @@ div:last-child > .theia-ChatNode { text-align: center; } -.theia-ChatInput-ChangeSet-List - .theia-ChatInput-ChangeSet-Icon.codicon::before { +.theia-ChatInput-ChangeSet-List .theia-ChatInput-ChangeSet-Icon.codicon::before { font-size: var(--theia-ui-font-size1); } @@ -468,8 +468,7 @@ div:last-child > .theia-ChatNode { color: var(--theia-disabledForeground); } -.theia-ChatInput-ChangeSet-List - .theia-ChatInput-ChangeSet-AdditionalInfo-SuffixIcon { +.theia-ChatInput-ChangeSet-List .theia-ChatInput-ChangeSet-AdditionalInfo-SuffixIcon { font-size: var(--theia-ui-font-size0) px; margin-left: 4px; } @@ -749,8 +748,7 @@ div:last-child > .theia-ChatNode { display: flex; flex-direction: column; gap: 8px; - border: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); padding: 8px 12px 12px; border-radius: 5px; margin: 0 0 8px 0; @@ -1106,8 +1104,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { /* Delegation response styles */ .theia-delegation-container { - border: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); border-radius: var(--theia-ui-padding); margin: var(--theia-ui-padding) 0; background-color: var(--theia-sideBar-background); @@ -1122,8 +1119,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { padding: var(--theia-ui-padding); background-color: var(--theia-editorGroupHeader-tabsBackground); border-radius: var(--theia-ui-padding) var(--theia-ui-padding) 0 0; - border-bottom: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border-bottom: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); list-style: none; position: relative; } @@ -1226,8 +1222,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { .delegation-prompt-section { margin-bottom: var(--theia-ui-padding); padding-bottom: var(--theia-ui-padding); - border-bottom: var(--theia-border-width) solid - var(--theia-sideBarSectionHeader-border); + border-bottom: var(--theia-border-width) solid var(--theia-sideBarSectionHeader-border); } .delegation-prompt { @@ -1244,7 +1239,7 @@ details[open].collapsible-arguments .collapsible-arguments-summary { margin-top: var(--theia-ui-padding); } -.delegation-response-section > strong { +.delegation-response-section>strong { display: block; margin-bottom: var(--theia-ui-padding); color: var(--theia-foreground); @@ -1283,6 +1278,96 @@ details[open].collapsible-arguments .collapsible-arguments-summary { color: var(--theia-button-foreground, #fff); } +/* Stale request indicator styles */ +.theia-RequestNode-stale { + opacity: 0.7; +} + +/* Summary request styles */ +.theia-RequestNode-summary { + background-color: var(--theia-editor-inactiveSelectionBackground); + border-left: 3px solid var(--theia-focusBorder); + padding-left: 8px; + margin: 8px 0; + border-radius: 4px; +} + +.theia-RequestNode-SummaryHeader { + display: flex; + align-items: center; + gap: 6px; + font-weight: 500; + margin-bottom: 8px; + color: var(--theia-descriptionForeground); +} + +.theia-RequestNode-SummaryHeader .codicon { + font-size: 14px; +} + +.theia-RequestNode-SummaryContent { + margin-top: 8px; +} + + +/* Chat summary styles */ +.theia-chat-summary { + margin: 8px 0; + border: 1px solid var(--theia-sideBarSectionHeader-border); + border-radius: 4px; + background-color: var(--theia-editorGroupHeader-tabsBackground); +} + +.theia-chat-summary details { + width: 100%; +} + +.theia-chat-summary summary { + cursor: pointer; + padding: 8px 12px; + display: flex; + align-items: center; + gap: 8px; + font-weight: 600; + color: var(--theia-foreground); + user-select: none; + list-style: none; + position: relative; +} + +.theia-chat-summary summary::-webkit-details-marker { + display: none; +} + +.theia-chat-summary summary::before { + content: "\25BC"; + position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + transition: transform 0.2s ease; + color: var(--theia-descriptionForeground); + font-size: var(--theia-ui-font-size1); +} + +.theia-chat-summary details:not([open]) summary::before { + transform: translateY(-50%) rotate(-90deg); +} + +.theia-chat-summary summary:hover { + background-color: var(--theia-toolbar-hoverBackground); +} + +.theia-chat-summary summary .codicon { + color: var(--theia-descriptionForeground); +} + +.theia-chat-summary-content { + padding: 12px; + border-top: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-editor-background); +} + /* Unknown content styles */ .theia-chat-unknown-content { margin: 4px 0; @@ -1310,3 +1395,39 @@ details[open].collapsible-arguments .collapsible-arguments-summary { font-family: var(--theia-code-font-family); white-space: pre-wrap; } + +/* Token usage indicator styles */ +.chat-token-usage-container { + padding: 4px 8px; + border-top: 1px solid var(--theia-sideBarSectionHeader-border); + background-color: var(--theia-sideBar-background); +} + +.theia-ChatTokenUsageIndicator { + font-size: 11px; + display: flex; + align-items: center; + justify-content: center; + gap: 4px; + font-family: var(--theia-ui-font-family); +} + +.theia-ChatTokenUsageIndicator .token-usage-text { + white-space: nowrap; +} + +.theia-ChatTokenUsageIndicator.token-usage-none { + color: var(--theia-disabledForeground); +} + +.theia-ChatTokenUsageIndicator.token-usage-green { + color: var(--theia-successForeground, var(--theia-charts-green)); +} + +.theia-ChatTokenUsageIndicator.token-usage-yellow { + color: var(--theia-editorWarning-foreground, var(--theia-charts-yellow)); +} + +.theia-ChatTokenUsageIndicator.token-usage-red { + color: var(--theia-errorForeground, var(--theia-charts-red)); +} diff --git a/packages/ai-chat/README.md b/packages/ai-chat/README.md index 8aafbd53d7a22..ceca078422f02 100644 --- a/packages/ai-chat/README.md +++ b/packages/ai-chat/README.md @@ -15,6 +15,30 @@ The `@theia/ai-chat` extension provides the concept of a language model chat to Theia. It serves as the basis for `@theia/ai-chat-ui` to provide the Chat UI. +## Features + +### Budget-Aware Tool Loop (Experimental) + +When enabled via the `ai-features.chat.experimentalBudgetAwareToolLoop` preference, the chat system can automatically trigger conversation summarization when the token budget is exceeded during tool call loops. This prevents "context too long" API errors during complex multi-tool tasks. + +**How it works:** +1. When a chat request includes tools, the system manages the tool loop externally instead of letting the language model handle it internally +2. Between tool call iterations, the system checks if the token budget is exceeded +3. If exceeded, it triggers summarization to compress the conversation history +4. The task continues with the summarized context plus the pending tool calls + +**Requirements:** +- Currently only supports Anthropic models (other models fall back to standard behavior) +- Requires the language model to support the `singleRoundTrip` request property + +**Preference:** + +```json +{ + "ai-features.chat.experimentalBudgetAwareToolLoop": true +} +``` + ## Additional Information - [API documentation for `@theia/ai-chat`](https://eclipse-theia.github.io/theia/docs/next/modules/_theia_ai-chat.html) diff --git a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts index 7661bab3c6d17..9005993b2fcf2 100644 --- a/packages/ai-chat/src/browser/ai-chat-frontend-module.ts +++ b/packages/ai-chat/src/browser/ai-chat-frontend-module.ts @@ -14,7 +14,7 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { Agent, AgentService, AIVariableContribution, bindToolProvider } from '@theia/ai-core/lib/common'; +import { Agent, AgentService, AIVariableContribution, bindToolProvider, LanguageModelService } from '@theia/ai-core/lib/common'; import { bindContributionProvider, CommandContribution, PreferenceContribution } from '@theia/core'; import { FrontendApplicationContribution, LabelProviderContribution } from '@theia/core/lib/browser'; import { ContainerModule } from '@theia/core/shared/inversify'; @@ -28,7 +28,8 @@ import { ToolCallChatResponseContentFactory, PinChatAgent, ChatServiceFactory, - ChatAgentServiceFactory + ChatAgentServiceFactory, + ChatSessionSummarizationServiceSymbol } from '../common'; import { ChatAgentsVariableContribution } from '../common/chat-agents-variable-contribution'; import { CustomChatAgent } from '../common/custom-chat-agent'; @@ -73,8 +74,11 @@ import { ChangeSetElementDeserializerRegistryImpl } from '../common/change-set-element-deserializer'; import { ChangeSetFileElementDeserializerContribution } from './change-set-file-element-deserializer'; +import { ChatSessionTokenTracker, ChatSessionTokenTrackerImpl } from './chat-session-token-tracker'; +import { ChatSessionSummarizationService, ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; +import { ChatLanguageModelServiceImpl } from './chat-language-model-service'; -export default new ContainerModule(bind => { +export default new ContainerModule((bind, unbind, isBound, rebind) => { bindContributionProvider(bind, ChatAgent); bind(ChatContentDeserializerRegistryImpl).toSelf().inSingletonScope(); @@ -186,4 +190,16 @@ export default new ContainerModule(bind => { bind(CommandContribution).toService(AIChatFrontendContribution); bindToolProvider(AgentDelegationTool, bind); + + bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); + bind(ChatSessionTokenTracker).toService(ChatSessionTokenTrackerImpl); + + bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); + bind(ChatSessionSummarizationService).toService(ChatSessionSummarizationServiceImpl); + bind(ChatSessionSummarizationServiceSymbol).toService(ChatSessionSummarizationService); + bind(FrontendApplicationContribution).toService(ChatSessionSummarizationServiceImpl); + + // Rebind LanguageModelService to use the chat-aware implementation with budget-aware tool loop + bind(ChatLanguageModelServiceImpl).toSelf().inSingletonScope(); + rebind(LanguageModelService).toService(ChatLanguageModelServiceImpl); }); diff --git a/packages/ai-chat/src/browser/chat-language-model-service.spec.ts b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts new file mode 100644 index 0000000000000..392980d383971 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-language-model-service.spec.ts @@ -0,0 +1,562 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { Container } from '@theia/core/shared/inversify'; +import { ILogger, PreferenceService } from '@theia/core'; +import { + LanguageModel, + LanguageModelRegistry, + LanguageModelStreamResponse, + LanguageModelStreamResponsePart, + UserRequest, + isLanguageModelStreamResponse +} from '@theia/ai-core'; +import { ChatLanguageModelServiceImpl } from './chat-language-model-service'; +import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatSessionSummarizationService } from './chat-session-summarization-service'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '../common/ai-chat-preferences'; +import { PREFERENCE_NAME_REQUEST_SETTINGS } from '@theia/ai-core/lib/common/ai-core-preferences'; + +describe('ChatLanguageModelServiceImpl', () => { + let container: Container; + let service: ChatLanguageModelServiceImpl; + let mockLanguageModel: sinon.SinonStubbedInstance; + let mockPreferenceService: sinon.SinonStubbedInstance; + let mockTokenTracker: sinon.SinonStubbedInstance; + let mockSummarizationService: sinon.SinonStubbedInstance; + let mockLogger: sinon.SinonStubbedInstance; + + beforeEach(() => { + container = new Container(); + + // Create mocks + mockLanguageModel = { + id: 'test/model', + name: 'Test Model', + request: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockPreferenceService = { + get: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockTokenTracker = { + getSessionInputTokens: sinon.stub(), + getSessionOutputTokens: sinon.stub(), + getSessionTotalTokens: sinon.stub(), + updateSessionTokens: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + mockSummarizationService = { + hasSummary: sinon.stub(), + markPendingSplit: sinon.stub(), + checkAndHandleSummarization: sinon.stub().resolves(false) + } as unknown as sinon.SinonStubbedInstance; + + mockLogger = { + info: sinon.stub(), + warn: sinon.stub(), + error: sinon.stub(), + debug: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + // Bind mocks + container.bind(LanguageModelRegistry).toConstantValue({} as LanguageModelRegistry); + container.bind(PreferenceService).toConstantValue(mockPreferenceService); + container.bind(ChatSessionTokenTracker).toConstantValue(mockTokenTracker); + container.bind(ChatSessionSummarizationService).toConstantValue(mockSummarizationService); + container.bind(ILogger).toConstantValue(mockLogger); + container.bind(ChatLanguageModelServiceImpl).toSelf().inSingletonScope(); + + service = container.get(ChatLanguageModelServiceImpl); + + // Default preference setup + mockPreferenceService.get.withArgs(PREFERENCE_NAME_REQUEST_SETTINGS, []).returns([]); + }); + + afterEach(() => { + sinon.restore(); + }); + + describe('sendRequest', () => { + it('should delegate to super when preference is disabled', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(false); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + expect(mockSummarizationService.markPendingSplit.called).to.be.false; + }); + + it('should delegate to super when request has no tools', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }] + // No tools + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + expect(mockSummarizationService.markPendingSplit.called).to.be.false; + }); + + it('should use budget-aware handling when preference is enabled and tools are present', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); // Below threshold + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify singleRoundTrip was set + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.singleRoundTrip).to.be.true; + }); + }); + + describe('budget checking', () => { + it('should call markPendingSplit after tool execution when budget is exceeded mid-loop', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + // Return over threshold - budget check happens after tool execution + mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // Model returns tool call without result + const mockStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the tool loop + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify markPendingSplit was called after tool execution (mid-turn budget exceeded) + expect(mockSummarizationService.markPendingSplit.calledOnce).to.be.true; + const markPendingCall = mockSummarizationService.markPendingSplit.firstCall; + expect(markPendingCall.args[0]).to.equal('session-1'); + expect(markPendingCall.args[1]).to.equal('request-1'); + // Third arg should be pending tool calls array + expect(markPendingCall.args[2]).to.be.an('array'); + // Fourth arg should be tool results map + expect(markPendingCall.args[3]).to.be.instanceOf(Map); + + // Loop should exit after split + expect(mockLanguageModel.request.calledOnce).to.be.true; + }); + + it('should not trigger markPendingSplit when no tool calls are made', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(CHAT_TOKEN_THRESHOLD + 1000); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + // Model returns response without tool calls + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // markPendingSplit should NOT be called when there are no pending tool calls + expect(mockSummarizationService.markPendingSplit.called).to.be.false; + }); + + it('should preserve original messages (no message rebuilding)', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); // Below threshold + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [ + { actor: 'system', type: 'text', text: 'System prompt' }, + { actor: 'user', type: 'text', text: 'Hello' } + ], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify original messages are preserved + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.messages).to.have.length(2); + expect((actualRequest.messages[0] as { text: string }).text).to.equal('System prompt'); + expect((actualRequest.messages[1] as { text: string }).text).to.equal('Hello'); + }); + }); + + describe('tool loop handling', () => { + it('should execute tools and continue loop when model respects singleRoundTrip', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // First call: model returns tool call without result (respected singleRoundTrip) + const firstStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Second call: model returns final response + const secondStream = createMockStream([ + { content: 'Done!' } + ]); + + mockLanguageModel.request + .onFirstCall().resolves({ stream: firstStream }) + .onSecondCall().resolves({ stream: secondStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume the stream to trigger the tool loop + const parts: LanguageModelStreamResponsePart[] = []; + for await (const part of (response as LanguageModelStreamResponse).stream) { + parts.push(part); + } + + // Verify tool was executed + expect(toolHandler.calledOnce).to.be.true; + + // Verify two LLM calls were made (initial + continuation) + expect(mockLanguageModel.request.calledTwice).to.be.true; + }); + + it('should not execute tools when model ignores singleRoundTrip (has results)', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // Model returns tool call WITH result (ignored singleRoundTrip, handled internally) + const mockStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true, result: 'internal result' }] }, + { content: 'Done!' } + ]); + + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consume the stream + const parts: LanguageModelStreamResponsePart[] = []; + for await (const part of (response as LanguageModelStreamResponse).stream) { + parts.push(part); + } + + // Tool should NOT have been executed by our service (model did it) + expect(toolHandler.called).to.be.false; + + // Only one LLM call (model handled everything) + expect(mockLanguageModel.request.calledOnce).to.be.true; + }); + }); + + describe('subRequestId handling', () => { + it('should set subRequestId with format requestId-0 on first call', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const mockStream = createMockStream([{ content: 'Response' }]); + mockLanguageModel.request.resolves({ stream: mockStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume stream to trigger the actual request + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + expect(mockLanguageModel.request.calledOnce).to.be.true; + const actualRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(actualRequest.subRequestId).to.equal('request-1-0'); + }); + + it('should increment subRequestId across multiple tool loop iterations', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const toolHandler = sinon.stub().resolves('tool result'); + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ + id: 'tool-1', + name: 'test-tool', + parameters: { type: 'object', properties: {} }, + handler: toolHandler + }] + }; + + // First call: model returns tool call without result (respected singleRoundTrip) + const firstStream = createMockStream([ + { content: 'Let me use a tool' }, + { tool_calls: [{ id: 'call-1', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Second call: model returns another tool call + const secondStream = createMockStream([ + { content: 'Using another tool' }, + { tool_calls: [{ id: 'call-2', function: { name: 'test-tool', arguments: '{}' }, finished: true }] } + ]); + + // Third call: model returns final response + const thirdStream = createMockStream([ + { content: 'Done!' } + ]); + + mockLanguageModel.request + .onFirstCall().resolves({ stream: firstStream }) + .onSecondCall().resolves({ stream: secondStream }) + .onThirdCall().resolves({ stream: thirdStream }); + + const response = await service.sendRequest(mockLanguageModel, request); + + // Consume the stream to trigger the tool loop + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // just consume + } + + // Verify three LLM calls were made + expect(mockLanguageModel.request.calledThrice).to.be.true; + + // Verify subRequestId increments: request-1-0, request-1-1, request-1-2 + const firstCallRequest = mockLanguageModel.request.firstCall.args[0] as UserRequest; + expect(firstCallRequest.subRequestId).to.equal('request-1-0'); + + const secondCallRequest = mockLanguageModel.request.secondCall.args[0] as UserRequest; + expect(secondCallRequest.subRequestId).to.equal('request-1-1'); + + const thirdCallRequest = mockLanguageModel.request.thirdCall.args[0] as UserRequest; + expect(thirdCallRequest.subRequestId).to.equal('request-1-2'); + }); + }); + + describe('error handling', () => { + it('should throw error when model returns non-streaming response', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + // Model returns a non-streaming response (just text, not a stream) + mockLanguageModel.request.resolves({ text: 'Non-streaming response' }); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consuming the stream should throw + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw before we get here + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.be.instanceOf(Error); + expect((error as Error).message).to.equal('Budget-aware tool loop requires streaming response. Model returned non-streaming response.'); + } + }); + + it('should log context-too-long errors and re-throw', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const contextError = new Error('Request too long: context exceeds maximum token limit'); + mockLanguageModel.request.rejects(contextError); + + const response = await service.sendRequest(mockLanguageModel, request); + expect(isLanguageModelStreamResponse(response)).to.be.true; + + // Consuming the stream should throw + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw before we get here + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.equal(contextError); + expect(mockLogger.error.calledOnce).to.be.true; + expect(mockLogger.error.firstCall.args[0]).to.include('Context too long'); + } + }); + + it('should propagate non-context errors without special logging', async () => { + mockPreferenceService.get.withArgs(BUDGET_AWARE_TOOL_LOOP_PREF, false).returns(true); + mockTokenTracker.getSessionInputTokens.returns(100); + + const request: UserRequest = { + sessionId: 'session-1', + requestId: 'request-1', + messages: [{ actor: 'user', type: 'text', text: 'Hello' }], + tools: [{ id: 'tool-1', name: 'test-tool', parameters: { type: 'object', properties: {} }, handler: async () => 'result' }] + }; + + const networkError = new Error('Network connection failed'); + mockLanguageModel.request.rejects(networkError); + + const response = await service.sendRequest(mockLanguageModel, request); + + try { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + for await (const _part of (response as LanguageModelStreamResponse).stream) { + // Should throw + } + expect.fail('Should have thrown an error'); + } catch (error) { + expect(error).to.equal(networkError); + // Should not have logged "Context too long" message + expect(mockLogger.error.called).to.be.false; + } + }); + }); +}); + +/** + * Helper to create a mock async iterable stream from an array of parts. + */ +function createMockStream(parts: LanguageModelStreamResponsePart[]): AsyncIterable { + return { + async *[Symbol.asyncIterator](): AsyncIterator { + for (const part of parts) { + yield part; + } + } + }; +} diff --git a/packages/ai-chat/src/browser/chat-language-model-service.ts b/packages/ai-chat/src/browser/chat-language-model-service.ts new file mode 100644 index 0000000000000..277408e186613 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-language-model-service.ts @@ -0,0 +1,349 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { injectable, inject } from '@theia/core/shared/inversify'; +import { ILogger, PreferenceService } from '@theia/core'; +import { LanguageModelServiceImpl } from '@theia/ai-core/lib/common/language-model-service'; +import { + LanguageModel, + LanguageModelResponse, + LanguageModelStreamResponse, + LanguageModelStreamResponsePart, + UserRequest, + isLanguageModelStreamResponse, + isToolCallResponsePart, + ToolCall, + ToolRequest, + ToolCallResult, + LanguageModelMessage, + LanguageModelRegistry +} from '@theia/ai-core'; +import { BUDGET_AWARE_TOOL_LOOP_PREF } from '../common/ai-chat-preferences'; +import { ChatSessionTokenTracker, CHAT_TOKEN_THRESHOLD } from './chat-session-token-tracker'; +import { ChatSessionSummarizationService } from './chat-session-summarization-service'; +import { applyRequestSettings } from '@theia/ai-core/lib/browser/frontend-language-model-service'; + +/** + * Chat-specific language model service that adds budget-aware tool loop handling. + * Extends LanguageModelServiceImpl to intercept sendRequest() calls. + * + * When the experimental preference is enabled, this service: + * 1. Sets singleRoundTrip=true to prevent models from handling tool loops internally + * 2. Manages the tool loop externally with budget checks between iterations + * 3. Triggers summarization when token budget is exceeded mid-turn + * + * Models that don't support singleRoundTrip will ignore the flag - this is detected + * by checking if tool_calls have results attached (model handled internally) vs + * no results (model respected the flag). + */ +@injectable() +export class ChatLanguageModelServiceImpl extends LanguageModelServiceImpl { + + @inject(LanguageModelRegistry) + protected override languageModelRegistry: LanguageModelRegistry; + + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + + @inject(PreferenceService) + protected readonly preferenceService: PreferenceService; + + @inject(ILogger) + protected readonly logger: ILogger; + + @inject(ChatSessionSummarizationService) + protected readonly summarizationService: ChatSessionSummarizationService; + + override async sendRequest( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + applyRequestSettings(request, languageModel.id, request.agentId, this.preferenceService); + + const budgetAwareEnabled = this.preferenceService.get(BUDGET_AWARE_TOOL_LOOP_PREF, false); + + if (budgetAwareEnabled && request.tools?.length) { + return this.sendRequestWithBudgetAwareness(languageModel, request); + } + + return super.sendRequest(languageModel, request); + } + + /** + * Send request with budget-aware tool loop handling. + * Manages the tool loop externally, checking token budget between iterations + * and triggering summarization when needed. + */ + protected async sendRequestWithBudgetAwareness( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + const modifiedRequest: UserRequest = { + ...request, + singleRoundTrip: true + }; + return this.executeToolLoop(languageModel, modifiedRequest); + } + + /** + * Execute the tool loop, handling tool calls and budget checks between iterations. + * This method coordinates the overall flow, delegating to helper methods for specific tasks. + */ + protected async executeToolLoop( + languageModel: LanguageModel, + request: UserRequest + ): Promise { + const that = this; + const sessionId = request.sessionId; + const tools = request.tools ?? []; + + // State that persists across the async iterator + let currentMessages = [...request.messages]; + + const asyncIterator = { + async *[Symbol.asyncIterator](): AsyncIterator { + let continueLoop = true; + let iteration = 0; + + while (continueLoop) { + continueLoop = false; + + // Get response from model + const response = await that.sendSingleRoundTripRequest( + languageModel, request, currentMessages, sessionId, iteration + ); + + // Process the stream and collect tool calls + const streamProcessor = that.processResponseStream(response.stream); + let streamResult: IteratorResult; + + // Yield all parts from the stream processor + while (!(streamResult = await streamProcessor.next()).done) { + yield streamResult.value; + } + + const { pendingToolCalls, modelHandledLoop } = streamResult.value; + + // If model handled the loop internally, we're done + if (modelHandledLoop) { + return; + } + + // If there are pending tool calls, execute them and check if we need to split + if (pendingToolCalls.length > 0) { + const { toolResults, shouldSplit } = await that.executeToolsAndCheckBudget( + pendingToolCalls, tools, sessionId + ); + + if (shouldSplit && sessionId) { + // Budget exceeded - mark pending split and exit cleanly + that.summarizationService.markPendingSplit(sessionId, request.requestId, pendingToolCalls, toolResults); + return; + } + + // Normal case - append tool messages and continue loop + currentMessages = that.appendToolMessages( + currentMessages, + pendingToolCalls, + toolResults + ); + + // Yield tool call results + const resultsToYield = pendingToolCalls.map(tc => ({ + finished: true, + id: tc.id, + result: toolResults.get(tc.id!), + function: tc.function + })); + yield { tool_calls: resultsToYield }; + + iteration++; + continueLoop = true; + } + } + } + }; + + return { stream: asyncIterator }; + } + + /** + * Send a single round-trip request to the language model. + * Handles context-too-long errors and ensures streaming response. + */ + protected async sendSingleRoundTripRequest( + languageModel: LanguageModel, + request: UserRequest, + currentMessages: LanguageModelMessage[], + sessionId: string | undefined, + iteration: number + ): Promise { + const currentRequest: UserRequest = { + ...request, + messages: currentMessages, + singleRoundTrip: true, + subRequestId: `${request.requestId}-${iteration}` + }; + + let response: LanguageModelResponse; + try { + response = await LanguageModelServiceImpl.prototype.sendRequest.call( + this, languageModel, currentRequest + ); + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error); + if (errorMessage.toLowerCase().includes('context') || + errorMessage.toLowerCase().includes('token') || + errorMessage.toLowerCase().includes('too long') || + errorMessage.toLowerCase().includes('max_tokens')) { + this.logger.error( + 'Context too long error for session ' + sessionId + '. ' + + 'Cannot recover - summarization also requires an LLM call.', + error + ); + } + throw error; + } + + if (!isLanguageModelStreamResponse(response)) { + throw new Error('Budget-aware tool loop requires streaming response. Model returned non-streaming response.'); + } + + return response; + } + + /** + * Process a response stream, collecting tool calls and yielding parts. + * @returns Object with pendingToolCalls and whether the model handled the loop internally + */ + protected async *processResponseStream( + stream: AsyncIterable + ): AsyncGenerator { + const pendingToolCalls: ToolCall[] = []; + let modelHandledLoop = false; + + for await (const part of stream) { + if (isToolCallResponsePart(part)) { + for (const toolCall of part.tool_calls) { + // If any tool call has a result, the model handled the loop internally + if (toolCall.result !== undefined) { + modelHandledLoop = true; + } + // Collect finished tool calls without results (model respected singleRoundTrip) + if (toolCall.finished && toolCall.result === undefined && toolCall.id) { + pendingToolCalls.push(toolCall); + } + } + } + yield part; + } + + return { pendingToolCalls, modelHandledLoop }; + } + + /** + * Execute pending tool calls and check if budget is exceeded. + * Returns a signal indicating if the turn should be split. + */ + protected async executeToolsAndCheckBudget( + pendingToolCalls: ToolCall[], + tools: ToolRequest[], + sessionId: string | undefined + ): Promise<{ toolResults: Map; shouldSplit: boolean }> { + const toolResults = await this.executeTools(pendingToolCalls, tools); + + const shouldSplit = sessionId !== undefined && this.isBudgetExceeded(sessionId); + + return { toolResults, shouldSplit }; + } + + /** + * Check if the token budget is exceeded for a session. + */ + protected isBudgetExceeded(sessionId: string | undefined): boolean { + if (!sessionId) { + return false; + } + const tokens = this.tokenTracker.getSessionInputTokens(sessionId); + return tokens !== undefined && tokens >= CHAT_TOKEN_THRESHOLD; + } + + /** + * Execute tool calls and collect results. + */ + protected async executeTools( + toolCalls: ToolCall[], + toolRequests: ToolRequest[] + ): Promise> { + const results = new Map(); + + for (const toolCall of toolCalls) { + const toolRequest = toolRequests.find(t => t.name === toolCall.function?.name); + if (toolRequest && toolCall.id && toolCall.function?.arguments) { + try { + const result = await toolRequest.handler(toolCall.function.arguments); + results.set(toolCall.id, result); + } catch (error) { + this.logger.error(`Tool execution failed for ${toolCall.function?.name}:`, error); + results.set(toolCall.id, { type: 'error', data: String(error) } as ToolCallResult); + } + } + } + + return results; + } + + /** + * Append tool_use and tool_result messages to the message array. + */ + protected appendToolMessages( + messages: LanguageModelMessage[], + toolCalls: ToolCall[], + toolResults: Map + ): LanguageModelMessage[] { + const newMessages: LanguageModelMessage[] = [...messages]; + + // Add tool_use messages (AI requesting tool calls) + for (const toolCall of toolCalls) { + if (toolCall.id && toolCall.function?.name) { + newMessages.push({ + actor: 'ai', + type: 'tool_use', + id: toolCall.id, + name: toolCall.function.name, + input: toolCall.function.arguments ? JSON.parse(toolCall.function.arguments) : {} + }); + } + } + + // Add tool_result messages (user providing results) + for (const toolCall of toolCalls) { + if (toolCall.id && toolCall.function?.name) { + const result = toolResults.get(toolCall.id); + newMessages.push({ + actor: 'user', + type: 'tool_result', + tool_use_id: toolCall.id, + name: toolCall.function.name, + content: result + }); + } + } + + return newMessages; + } +} + diff --git a/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts b/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts index 61e862d980739..a57b905cf1ba4 100644 --- a/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts +++ b/packages/ai-chat/src/browser/chat-session-store-impl.spec.ts @@ -34,6 +34,7 @@ import { BinaryBuffer } from '@theia/core/lib/common/buffer'; import { ChatSessionIndex, ChatSessionMetadata } from '../common/chat-session-store'; import { PERSISTED_SESSION_LIMIT_PREF } from '../common/ai-chat-preferences'; import { ChatAgentLocation } from '../common/chat-agents'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; disableJSDOM(); @@ -113,6 +114,12 @@ describe('ChatSessionStoreImpl', () => { container.bind('ChatSessionStore').toConstantValue(mockLogger); container.bind(ILogger).toConstantValue(mockLogger).whenTargetNamed('ChatSessionStore'); + const mockTokenTracker = { + getSessionInputTokens: sandbox.stub().returns(undefined), + getBranchTokensForSession: sandbox.stub().returns(undefined) + } as unknown as ChatSessionTokenTracker; + container.bind(ChatSessionTokenTracker).toConstantValue(mockTokenTracker); + container.bind(ChatSessionStoreImpl).toSelf().inSingletonScope(); chatSessionStore = container.get(ChatSessionStoreImpl); diff --git a/packages/ai-chat/src/browser/chat-session-store-impl.ts b/packages/ai-chat/src/browser/chat-session-store-impl.ts index f4ee2a062ce17..8ea2ae984eb20 100644 --- a/packages/ai-chat/src/browser/chat-session-store-impl.ts +++ b/packages/ai-chat/src/browser/chat-session-store-impl.ts @@ -27,6 +27,7 @@ import { ChatModel } from '../common/chat-model'; import { ChatSessionIndex, ChatSessionStore, ChatModelWithMetadata, ChatSessionMetadata } from '../common/chat-session-store'; import { PERSISTED_SESSION_LIMIT_PREF } from '../common/ai-chat-preferences'; import { SerializedChatData, CHAT_DATA_VERSION } from '../common/chat-model-serialization'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; const INDEX_FILE = 'index.json'; @@ -50,6 +51,9 @@ export class ChatSessionStoreImpl implements ChatSessionStore { @inject(PreferenceService) protected readonly preferenceService: PreferenceService; + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + protected storageRoot?: URI; protected indexCache?: ChatSessionIndex; protected storePromise: Promise = Promise.resolve(); @@ -75,7 +79,9 @@ export class ChatSessionStoreImpl implements ChatSessionStore { title: session.title, pinnedAgentId: session.pinnedAgentId, saveDate: session.saveDate, - model: modelData + model: modelData, + lastInputTokens: this.tokenTracker.getSessionInputTokens(session.model.id), + branchTokens: this.tokenTracker.getBranchTokensForSession(session.model.id) }; this.logger.debug('Writing session to file', { sessionId: session.model.id, diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts new file mode 100644 index 0000000000000..1eb1fb036bce0 --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.spec.ts @@ -0,0 +1,560 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { Container } from '@theia/core/shared/inversify'; +import { Emitter, ILogger } from '@theia/core'; +import { ToolCall, ToolCallResult, UsageResponsePart } from '@theia/ai-core'; +import { ChatSessionSummarizationServiceImpl } from './chat-session-summarization-service'; +import { ChatSessionTokenTracker } from './chat-session-token-tracker'; +import { ChatService, SessionCreatedEvent, SessionDeletedEvent } from '../common/chat-service'; +import { ChatSession } from '../common'; +import { ChatSessionStore } from '../common/chat-session-store'; + +describe('ChatSessionSummarizationServiceImpl', () => { + let container: Container; + let service: ChatSessionSummarizationServiceImpl; + let tokenTracker: sinon.SinonStubbedInstance; + let chatService: sinon.SinonStubbedInstance; + let logger: sinon.SinonStubbedInstance; + + let sessionEventEmitter: Emitter; + let sessionRegistry: Map; + let sessionStore: sinon.SinonStubbedInstance; + + // Helper to create a mock UsageResponsePart + function createUsageResponsePart(params: { + input_tokens: number; + output_tokens: number; + cache_creation_input_tokens?: number; + cache_read_input_tokens?: number; + }): UsageResponsePart { + return { + input_tokens: params.input_tokens, + output_tokens: params.output_tokens, + cache_creation_input_tokens: params.cache_creation_input_tokens, + cache_read_input_tokens: params.cache_read_input_tokens + }; + } + + // Helper to create a mock session + function createMockSession(sessionId: string, activeBranchId: string, branches: { id: string }[] = []): ChatSession { + const modelChangeEmitter = new Emitter(); + const allBranches = branches.length > 0 ? branches : [{ id: activeBranchId }]; + return { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub().callsFake((requestId: string) => { + // Return branch based on requestId pattern: 'request-for-branchX' => { id: 'branchX' } + const match = requestId.match(/request-for-(.+)/); + if (match) { + return { id: match[1] }; + } + return undefined; + }), + getBranches: sinon.stub().returns(allBranches), + getRequest: sinon.stub().callsFake((requestId: string) => { + if (requestId.includes('summary')) { + return { request: { kind: 'summary' } }; + } + return { request: { kind: 'user' } }; + }), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + } + + beforeEach(() => { + container = new Container(); + + // Create emitter for session event simulation + sessionEventEmitter = new Emitter(); + + // Create session registry for dynamic lookup + sessionRegistry = new Map(); + + // Create stubs + const branchTokensMap = new Map(); + tokenTracker = { + resetSessionTokens: sinon.stub(), + getSessionInputTokens: sinon.stub(), + getSessionOutputTokens: sinon.stub(), + getSessionTotalTokens: sinon.stub(), + updateSessionTokens: sinon.stub(), + onSessionTokensUpdated: sinon.stub(), + setBranchTokens: sinon.stub().callsFake((sessionId: string, branchId: string, tokens: number) => { + branchTokensMap.set(`${sessionId}:${branchId}`, tokens); + }), + getBranchTokens: sinon.stub().callsFake((sessionId: string, branchId: string) => branchTokensMap.get(`${sessionId}:${branchId}`)), + getBranchTokensForSession: sinon.stub().callsFake((sessionId: string) => { + const result: { [branchId: string]: number } = {}; + const prefix = `${sessionId}:`; + for (const [key, value] of branchTokensMap.entries()) { + if (key.startsWith(prefix)) { + const branchId = key.substring(prefix.length); + result[branchId] = value; + } + } + return result; + }), + restoreBranchTokens: sinon.stub().callsFake((sessionId: string, branchTokens: { [branchId: string]: number }) => { + for (const [branchId, tokens] of Object.entries(branchTokens)) { + branchTokensMap.set(`${sessionId}:${branchId}`, tokens); + } + }), + clearSessionBranchTokens: sinon.stub().callsFake((sessionId: string) => { + const prefix = `${sessionId}:`; + for (const key of branchTokensMap.keys()) { + if (key.startsWith(prefix)) { + branchTokensMap.delete(key); + } + } + }) + } as unknown as sinon.SinonStubbedInstance; + + chatService = { + getSession: sinon.stub().callsFake((sessionId: string) => sessionRegistry.get(sessionId)), + getSessions: sinon.stub().returns([]), + onSessionEvent: sessionEventEmitter.event, + sendRequest: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + logger = { + info: sinon.stub(), + warn: sinon.stub(), + error: sinon.stub(), + debug: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + + sessionStore = { + storeSessions: sinon.stub().resolves(), + readSession: sinon.stub().resolves(undefined), + deleteSession: sinon.stub().resolves(), + clearAllSessions: sinon.stub().resolves(), + getSessionIndex: sinon.stub().resolves({}) + } as unknown as sinon.SinonStubbedInstance; + + // Bind to container + container.bind(ChatSessionTokenTracker).toConstantValue(tokenTracker); + container.bind(ChatService).toConstantValue(chatService); + container.bind(ILogger).toConstantValue(logger); + container.bind(ChatSessionStore).toConstantValue(sessionStore); + container.bind(ChatSessionSummarizationServiceImpl).toSelf().inSingletonScope(); + + service = container.get(ChatSessionSummarizationServiceImpl); + // Manually call init since @postConstruct won't run in tests + (service as unknown as { init: () => void }).init(); + }); + + afterEach(() => { + sinon.restore(); + sessionEventEmitter.dispose(); + sessionRegistry.clear(); + }); + + describe('markPendingSplit', () => { + it('should store pending split data', async () => { + const sessionId = 'session-1'; + const requestId = 'request-1'; + const pendingToolCalls: ToolCall[] = [ + { id: 'tool-1', function: { name: 'test_tool', arguments: '{}' }, finished: false } + ]; + const toolResults = new Map([['tool-1', 'result']]); + + // Create mock session for handleMidTurnSplit + const session = createMockSession(sessionId, 'branch-1'); + sessionRegistry.set(sessionId, session); + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { addRequest: sinon.SinonStub }).addRequest = sinon.stub().returns({ + id: 'new-request', + response: { + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub(), + asDisplayString: sinon.stub().returns('Summary text') + } + } + }); + (modelStub as unknown as { getRequests: sinon.SinonStub }).getRequests = sinon.stub().returns([]); + + service.markPendingSplit(sessionId, requestId, pendingToolCalls, toolResults); + + // Verify pending split is stored by calling checkAndHandleSummarization + // which consumes the pending split and returns true + const mockAgent = { invoke: sinon.stub().resolves() }; + const mockResponse = { + isComplete: false, + complete: sinon.stub(), + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub() + } + }; + const mockRequest = { + id: requestId, + request: { kind: 'user' }, + response: mockResponse + }; + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + undefined + ); + + // Should return true because pending split was consumed + expect(result).to.be.true; + }); + }); + + describe('checkAndHandleSummarization', () => { + it('should return false when request kind is summary', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'summary' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); + + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + expect(result).to.be.false; + }); + + it('should return false when request kind is continuation and below threshold', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'continuation' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); // Below threshold + + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + expect(result).to.be.false; + // Verify token tracker was still updated for continuation requests + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 100, 50)).to.be.true; + }); + + it('should not skip continuation request when it exceeds threshold', async () => { + const sessionId = 'session-1'; + const session = createMockSession(sessionId, 'branch-1'); + sessionRegistry.set(sessionId, session); + + const mockAgent = { invoke: sinon.stub() }; + const completeStub = sinon.stub(); + const mockRequest = { + id: 'request-1', + request: { kind: 'continuation' }, + response: { + isComplete: false, + complete: completeStub + } + }; + // 7000 tokens > CHAT_TOKEN_THRESHOLD (6300) + const usage = createUsageResponsePart({ input_tokens: 7000, output_tokens: 500 }); + + // Mock model.insertSummary for performSummarization + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { insertSummary: sinon.SinonStub }).insertSummary = sinon.stub().resolves('Summary text'); + + // Call the method - it may or may not fully complete summarization + // depending on mocks, but the key behavior is that it doesn't skip + // the threshold check for continuation requests when above threshold + await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + // Verify token tracker was updated with high token values + // This confirms the method processed the usage data and didn't skip early + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 7000, 500)).to.be.true; + }); + + it('should return false when tokens are below threshold', async () => { + const sessionId = 'session-1'; + + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'user' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); // Below threshold + + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + expect(result).to.be.false; + }); + + it('should update token tracker with usage data', async () => { + const sessionId = 'session-1'; + const mockAgent = { invoke: sinon.stub() }; + const mockRequest = { + id: 'request-1', + request: { kind: 'user' }, + response: { isComplete: false } + }; + const usage = createUsageResponsePart({ + input_tokens: 1000, + output_tokens: 200, + cache_creation_input_tokens: 100, + cache_read_input_tokens: 50 + }); + + await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + // Total input = input_tokens + cache_creation + cache_read = 1000 + 100 + 50 = 1150 + expect(tokenTracker.updateSessionTokens.calledWith(sessionId, 1150, 200)).to.be.true; + }); + + it('should consume pending split and handle mid-turn split', async () => { + const sessionId = 'session-1'; + const requestId = 'request-1'; + const pendingToolCalls: ToolCall[] = [ + { id: 'tool-1', function: { name: 'test_tool', arguments: '{}' }, finished: false } + ]; + const toolResults = new Map([['tool-1', 'result']]); + + // Create mock session + const session = createMockSession(sessionId, 'branch-1'); + sessionRegistry.set(sessionId, session); + + // Mark pending split + service.markPendingSplit(sessionId, requestId, pendingToolCalls, toolResults); + + const mockAgent = { invoke: sinon.stub().resolves() }; + const mockResponse = { + isComplete: false, + complete: sinon.stub(), + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub() + } + }; + const mockRequest = { + id: requestId, + request: { kind: 'user' }, + response: mockResponse + }; + const usage = createUsageResponsePart({ input_tokens: 100, output_tokens: 50 }); + + // Mock model methods needed for handleMidTurnSplit + const modelStub = session.model as sinon.SinonStubbedInstance; + (modelStub as unknown as { addRequest: sinon.SinonStub }).addRequest = sinon.stub().returns({ + id: 'new-request', + response: { + response: { + content: [], + clearContent: sinon.stub(), + addContent: sinon.stub(), + addContents: sinon.stub(), + asDisplayString: sinon.stub().returns('Summary text') + } + } + }); + (modelStub as unknown as { getRequests: sinon.SinonStub }).getRequests = sinon.stub().returns([]); + + const result = await service.checkAndHandleSummarization( + sessionId, + mockAgent as unknown as import('../common').ChatAgent, + mockRequest as unknown as import('../common').MutableChatRequestModel, + usage + ); + + // Should return true because pending split was handled + expect(result).to.be.true; + // Response should be completed + expect(mockResponse.complete.called).to.be.true; + }); + }); + + describe('branch change handling', () => { + it('should restore stored tokens when branch changes', () => { + const sessionId = 'session-4'; + const branchA = 'branch-A'; + const branchB = 'branch-B'; + + // Pre-populate branch tokens via tracker + tokenTracker.setBranchTokens(sessionId, branchA, 2000); + tokenTracker.setBranchTokens(sessionId, branchB, 4000); + + // Create session and fire created event to set up branch change listener + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: branchB }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event to set up listener + sessionEventEmitter.fire({ type: 'created', sessionId }); + + // Simulate branch change to branch A + modelChangeEmitter.fire({ + kind: 'changeHierarchyBranch', + branch: { id: branchA } + }); + + // Verify tokenTracker.resetSessionTokens was called with branch A's tokens + expect(tokenTracker.resetSessionTokens.calledWith(sessionId, 2000)).to.be.true; + }); + + it('should emit undefined when switching to branch with no stored tokens', () => { + const sessionId = 'session-5'; + const unknownBranchId = 'branch-unknown'; + + // Create session without pre-populating tokens for the unknown branch + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: 'branch-other' }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event to set up listener + sessionEventEmitter.fire({ type: 'created', sessionId }); + + // Simulate branch change to unknown branch + modelChangeEmitter.fire({ + kind: 'changeHierarchyBranch', + branch: { id: unknownBranchId } + }); + + // Verify tokenTracker.resetSessionTokens was called with undefined + expect(tokenTracker.resetSessionTokens.calledWith(sessionId, undefined)).to.be.true; + }); + + it('should populate branchTokens on persistence restore', () => { + const sessionId = 'restored-session'; + const activeBranchId = 'branch-restored'; + + // Create session + const modelChangeEmitter = new Emitter(); + const session = { + id: sessionId, + isActive: true, + model: { + getBranch: sinon.stub(), + getBranches: sinon.stub().returns([{ id: activeBranchId }]), + getRequest: sinon.stub(), + onDidChange: modelChangeEmitter.event + } + } as unknown as ChatSession; + + sessionRegistry.set(sessionId, session); + + // Fire session created event with branchTokens data + const branchTokensData = { + 'branch-restored': 8000, + 'branch-other': 3000 + }; + sessionEventEmitter.fire({ + type: 'created', + sessionId, + branchTokens: branchTokensData + }); + + // Verify restoreBranchTokens was called with correct data + expect((tokenTracker.restoreBranchTokens as sinon.SinonStub).calledWith(sessionId, branchTokensData)).to.be.true; + + // Verify getBranchTokensForSession returns restored data + const branchTokens = tokenTracker.getBranchTokensForSession(sessionId); + expect(branchTokens).to.deep.equal(branchTokensData); + }); + }); + + describe('cleanupSession', () => { + it('should clean up all session data when session is deleted', () => { + const sessionId = 'session-to-cleanup'; + + // Pre-populate branch tokens via tracker + tokenTracker.setBranchTokens(sessionId, 'branch-A', 1000); + tokenTracker.setBranchTokens(sessionId, 'branch-B', 2000); + tokenTracker.setBranchTokens('other-session', 'branch-X', 5000); + + // Add pending split + service.markPendingSplit(sessionId, 'request-1', [], new Map()); + + // Add to triggeredBranches + const triggeredBranchesSet = (service as unknown as { triggeredBranches: Set }).triggeredBranches; + triggeredBranchesSet.add(`${sessionId}: branch-A`); + triggeredBranchesSet.add(`${sessionId}: branch-B`); + triggeredBranchesSet.add('other-session: branch-X'); + + // Fire session deleted event + sessionEventEmitter.fire({ type: 'deleted', sessionId }); + + // Verify clearSessionBranchTokens was called + expect((tokenTracker.clearSessionBranchTokens as sinon.SinonStub).calledWith(sessionId)).to.be.true; + + // Verify triggeredBranches entries for deleted session are removed + expect(triggeredBranchesSet.has(`${sessionId}: branch-A`)).to.be.false; + expect(triggeredBranchesSet.has(`${sessionId}: branch-B`)).to.be.false; + + // Verify other session's triggeredBranches entries are preserved + expect(triggeredBranchesSet.has('other-session: branch-X')).to.be.true; + }); + }); +}); diff --git a/packages/ai-chat/src/browser/chat-session-summarization-service.ts b/packages/ai-chat/src/browser/chat-session-summarization-service.ts new file mode 100644 index 0000000000000..fb9167809100c --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-summarization-service.ts @@ -0,0 +1,521 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { inject, injectable, postConstruct } from '@theia/core/shared/inversify'; +import { ILogger, nls } from '@theia/core'; +import { FrontendApplicationContribution } from '@theia/core/lib/browser'; +import { ToolCall, ToolCallResult, UsageResponsePart } from '@theia/ai-core'; +import { + ChatAgent, + ChatService, + ChatSession, + ErrorChatResponseContent, + ErrorChatResponseContentImpl, + MutableChatModel, + MutableChatRequestModel, + MutableChatResponseModel, + ParsedChatRequest, + SummaryChatResponseContentImpl, + ToolCallChatResponseContent, + ToolCallChatResponseContentImpl +} from '../common'; +import { isSessionCreatedEvent, isSessionDeletedEvent } from '../common/chat-service'; +import { + CHAT_TOKEN_THRESHOLD, + ChatSessionTokenTracker +} from './chat-session-token-tracker'; + +export const ChatSessionSummarizationService = Symbol('ChatSessionSummarizationService'); + +/** + * Service that automatically summarizes chat sessions when token usage exceeds the threshold. + * + * When the threshold is exceeded: + * 1. Marks older messages as stale (excluding them from future prompts) + * 2. Invokes ChatSessionSummaryAgent to generate a summary + * 3. Inserts a summary node into the chat + */ +export interface ChatSessionSummarizationService { + /** + * Check if a session has been summarized (has stale messages). + */ + hasSummary(sessionId: string): boolean; + + /** + * Mark a pending mid-turn split. Called by the tool loop when budget exceeded. + * The split will be handled by checkAndHandleSummarization() after addContentsToResponse(). + */ + markPendingSplit( + sessionId: string, + requestId: string, + pendingToolCalls: ToolCall[], + toolResults: Map + ): void; + + /** + * Check and handle summarization after response content is added. + * Handles both mid-turn splits (from markPendingSplit) and between-turn summarization. + * + * @param sessionId The session ID + * @param agent The chat agent to invoke for summary/continuation + * @param request The current request being processed + * @param usage Usage data from the response stream for synchronous token tracking. + * May be undefined if the stream ended early (e.g., due to budget-exceeded split). + * @returns true if summarization was triggered (caller should skip onResponseComplete), false otherwise + */ + checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined + ): Promise; +} + +@injectable() +export class ChatSessionSummarizationServiceImpl implements ChatSessionSummarizationService, FrontendApplicationContribution { + @inject(ChatSessionTokenTracker) + protected readonly tokenTracker: ChatSessionTokenTracker; + + @inject(ChatService) + protected readonly chatService: ChatService; + + @inject(ILogger) + protected readonly logger: ILogger; + + /** + * Set of sessionIds currently being summarized to prevent concurrent summarization. + */ + protected summarizingSession = new Set(); + + /** + * Tracks which branches have triggered summarization. + * Key format: `${sessionId}:${branchId}` + * + * Used for deduplication (prevents multiple triggers for the same branch during a single growth cycle). + * + * **Cleanup behavior:** + * - After successful summarization: The branch key is REMOVED to allow future re-triggering + * when tokens grow again past the threshold. + * - On session deletion: All entries with matching sessionId prefix are removed. + * - On branch change: New branches automatically get fresh tracking since their branchId + * differs from previously tracked branches. + */ + protected triggeredBranches: Set = new Set(); + + /** + * Stores pending mid-turn split data, keyed by sessionId. + * Consumed by checkAndHandleSummarization() after addContentsToResponse(). + */ + protected pendingSplits = new Map; + }>(); + + @postConstruct() + protected init(): void { + // Listen for new sessions and set up branch change listeners + this.chatService.onSessionEvent(event => { + if (isSessionCreatedEvent(event)) { + const session = this.chatService.getSession(event.sessionId); + if (session) { + this.setupBranchChangeListener(session); + } + // Restore branch tokens from persisted data + if (event.branchTokens) { + this.tokenTracker.restoreBranchTokens(event.sessionId, event.branchTokens); + } + // Emit initial token count for active branch + if (session) { + const activeBranchId = this.getActiveBranchId(session); + if (activeBranchId) { + const tokens = this.tokenTracker.getBranchTokens(event.sessionId, activeBranchId); + this.tokenTracker.resetSessionTokens(event.sessionId, tokens); + } + } + } else if (isSessionDeletedEvent(event)) { + this.cleanupSession(event.sessionId); + } + }); + } + + /** + * Called when the frontend application starts. + * Required by FrontendApplicationContribution to ensure this service is instantiated. + */ + onStart(): void { + // Set up branch change listeners for existing sessions + for (const session of this.chatService.getSessions()) { + this.setupBranchChangeListener(session); + } + } + + /** + * Get the active branch ID for a session. + */ + protected getActiveBranchId(session: ChatSession): string | undefined { + return session.model.getBranches().at(-1)?.id; + } + + /** + * Set up a listener for branch changes in a chat session. + * When a branch change occurs (e.g., user edits an older message), reset token tracking. + */ + protected setupBranchChangeListener(session: ChatSession): void { + session.model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + const storedTokens = this.tokenTracker.getBranchTokens(session.id, event.branch.id); + this.tokenTracker.resetSessionTokens(session.id, storedTokens); + } + }); + } + + markPendingSplit( + sessionId: string, + requestId: string, + pendingToolCalls: ToolCall[], + toolResults: Map + ): void { + this.pendingSplits.set(sessionId, { requestId, pendingToolCalls, toolResults }); + } + + /** + * Update token tracking during streaming. + * Called when usage data is received in the stream, before the response completes. + * This enables real-time token count updates in the UI. + */ + updateTokens(sessionId: string, usage: UsageResponsePart): void { + const totalInputTokens = usage.input_tokens + (usage.cache_creation_input_tokens ?? 0) + (usage.cache_read_input_tokens ?? 0); + this.tokenTracker.updateSessionTokens(sessionId, totalInputTokens, usage.output_tokens); + } + + async checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined + ): Promise { + // Check for pending mid-turn split first (may exist even without usage data) + const pendingSplit = this.pendingSplits.get(sessionId); + if (pendingSplit) { + // Consume immediately to prevent re-entry + this.pendingSplits.delete(sessionId); + await this.handleMidTurnSplit(sessionId, agent, request, pendingSplit); + return true; + } + + // If no usage data, nothing more to do + if (!usage) { + return false; + } + + // Always skip summary requests before any token work + if (request.request.kind === 'summary') { + return false; + } + + // Calculate tokens for all other requests (user and continuation) + const totalInputTokens = usage.input_tokens + (usage.cache_creation_input_tokens ?? 0) + (usage.cache_read_input_tokens ?? 0); + this.tokenTracker.updateSessionTokens(sessionId, totalInputTokens, usage.output_tokens); + + // Skip continuation requests only if below threshold + if (request.request.kind === 'continuation' && totalInputTokens < CHAT_TOKEN_THRESHOLD) { + return false; + } + + // Check threshold with fresh data + if (totalInputTokens >= CHAT_TOKEN_THRESHOLD) { + const session = this.chatService.getSession(sessionId); + if (session) { + // Complete current response first if not already + if (!request.response.isComplete) { + request.response.complete(); + } + await this.performSummarization(sessionId, session.model as MutableChatModel); + return true; + } else { + this.logger.warn(`Session ${sessionId} not found for between-turn summarization`); + } + } + + return false; + } + + protected async handleMidTurnSplit( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel, + pendingSplit: { requestId: string; pendingToolCalls: ToolCall[]; toolResults: Map } + ): Promise { + const session = this.chatService.getSession(sessionId); + if (!session) { + this.logger.warn(`Session ${sessionId} not found for mid-turn split`); + return; + } + const model = session.model as MutableChatModel; + + // Step 1: Remove pending tool calls from current response + this.removePendingToolCallsFromResponse(request.response, pendingSplit.pendingToolCalls); + + // Step 2: Complete current response + request.response.complete(); + + // Step 3: Create summary request (stale marking deferred so summary sees full history) + // eslint-disable-next-line max-len + const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. Do not include conversational elements, questions, or offers to continue. Do not start with a heading - output only the summary content.'; + + const summaryParsedRequest: ParsedChatRequest = { + request: { text: summaryPrompt, kind: 'summary' }, + parts: [{ kind: 'text', text: summaryPrompt, promptText: summaryPrompt, range: { start: 0, endExclusive: summaryPrompt.length } }], + toolRequests: new Map(), + variables: [] + }; + const summaryRequest = model.addRequest(summaryParsedRequest, undefined, { variables: [] }); + + // Invoke agent for summary (will populate summaryRequest.response) + await agent.invoke(summaryRequest); + + // Reset token tracking with summary output as new baseline BEFORE continuation + const summaryOutputTokens = this.tokenTracker.getSessionOutputTokens(sessionId) ?? 0; + this.updateTokenTrackingAfterSummary(sessionId, summaryOutputTokens); + + // Get summary text from response + const summaryText = summaryRequest.response.response.asDisplayString()?.trim() || ''; + + // Replace response content with SummaryChatResponseContent for proper rendering + summaryRequest.response.response.clearContent(); + summaryRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryText)); + + // Step 4: Mark ALL requests stale AFTER summary is generated (summary needed full history) + const allRequestsAfterSummary = model.getRequests(); + for (const req of allRequestsAfterSummary) { + // Don't mark continuation request stale (it will be created next) + if (req.request.kind !== 'continuation') { + (req as MutableChatRequestModel).isStale = true; + } + } + + // Step 5: Create continuation request with summary and pending tool calls + const continuationSuffix = 'The tool call above was executed. Please continue with your task ' + + 'based on the result. If you need to make more tool calls to complete the task, please do so. ' + + 'Once you have all the information needed, provide your final response.'; + const continuationInstruction = `${summaryText}\n\n${continuationSuffix}`; + + const continuationParsedRequest: ParsedChatRequest = { + request: { text: continuationInstruction, kind: 'continuation' }, + parts: [{ kind: 'text', text: continuationInstruction, promptText: continuationInstruction, range: { start: 0, endExclusive: continuationInstruction.length } }], + toolRequests: new Map(), + variables: [] + }; + const continuationRequest = model.addRequest(continuationParsedRequest, undefined, { variables: [] }); + + // Add tool call content to response for UI display + for (const toolCall of pendingSplit.pendingToolCalls) { + const result = pendingSplit.toolResults.get(toolCall.id!); + const toolContent = new ToolCallChatResponseContentImpl( + toolCall.id, + toolCall.function?.name, + toolCall.function?.arguments, + true, // finished + result + ); + continuationRequest.response.response.addContent(toolContent); + } + + // Step 6: Invoke agent for continuation (token tracking will update normally) + await agent.invoke(continuationRequest); + } + + protected removePendingToolCallsFromResponse( + response: MutableChatResponseModel, + pendingToolCalls: ToolCall[] + ): void { + const pendingIds = new Set(pendingToolCalls.map(tc => tc.id).filter(Boolean)); + const content = response.response.content; + + const filteredContent = content.filter(c => { + if (ToolCallChatResponseContent.is(c) && c.id && pendingIds.has(c.id)) { + return false; + } + return true; + }); + + response.response.clearContent(); + response.response.addContents(filteredContent); + } + + /** + * Execute a callback with summarization lock for the session. + * Ensures lock is released even if callback throws. + */ + protected async withSummarizationLock( + sessionId: string, + callback: () => Promise + ): Promise { + if (this.summarizingSession.has(sessionId)) { + return undefined; + } + this.summarizingSession.add(sessionId); + try { + return await callback(); + } finally { + this.summarizingSession.delete(sessionId); + } + } + + /** + * Update token tracking after successful summarization. + */ + protected updateTokenTrackingAfterSummary( + sessionId: string, + outputTokens: number + ): void { + this.tokenTracker.resetSessionTokens(sessionId, outputTokens); + // Update branch tokens and allow future re-triggering + const session = this.chatService.getSession(sessionId); + if (session) { + const activeBranchId = this.getActiveBranchId(session); + if (activeBranchId) { + const branchKey = `${sessionId}:${activeBranchId}`; + this.tokenTracker.setBranchTokens(sessionId, activeBranchId, outputTokens); + this.triggeredBranches.delete(branchKey); + } + } + } + + /** + * Core summarization logic shared by both threshold-triggered and explicit mid-turn summarization. + * + * @param skipReorder If true, skip removing/re-adding the trigger request (for mid-turn summarization + * where the request is actively being processed with tool calls) + * @returns The summary text on success, or `undefined` on failure + */ + protected async performSummarization(sessionId: string, model: MutableChatModel, skipReorder?: boolean): Promise { + return this.withSummarizationLock(sessionId, async () => { + // Always use 'end' position - other positions break hierarchy structure + const position = 'end'; + // eslint-disable-next-line max-len + const summaryPrompt = 'Please provide a concise summary of our conversation so far, capturing all key requirements, decisions, context, and pending tasks so we can seamlessly continue. ' + + 'Do not include conversational elements, questions, or offers to continue. Do not start with a heading - output only the summary content.'; + + try { + const summaryText = await model.insertSummary( + async () => { + const invocation = await this.chatService.sendRequest(sessionId, { + text: summaryPrompt, + kind: 'summary' + }); + if (!invocation) { + return undefined; + } + + const request = await invocation.requestCompleted; + const response = await invocation.responseCompleted; + + // Validate response + const summaryResponseText = response.response.asDisplayString()?.trim(); + if (response.isError) { + this.logger.error(`Summary response has error: ${response.errorObject?.message}`); + return undefined; + } + if (!summaryResponseText) { + this.logger.error(`Summary response text is empty. Content count: ${response.response.content.length}, ` + + `content kinds: ${response.response.content.map(c => c.kind).join(', ')}`); + return undefined; + } + + // Replace agent's markdown content with SummaryChatResponseContent for proper rendering + const mutableRequest = request as MutableChatRequestModel; + mutableRequest.response.response.clearContent(); + mutableRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryResponseText)); + + return { + requestId: request.id, + summaryText: summaryResponseText + }; + }, + position + ); + + if (!summaryText) { + this.logger.warn(`Summarization failed for session ${sessionId}`); + this.notifyUserOfFailure(model); + return undefined; + } + + // Get output tokens from tracker (handleTokenUsage now tracks summary requests) + const outputTokens = this.tokenTracker.getSessionOutputTokens(sessionId) ?? 0; + + this.updateTokenTrackingAfterSummary(sessionId, outputTokens); + + return summaryText; + } catch (error) { + this.logger.error(`Failed to summarize session ${sessionId}: `, error); + this.notifyUserOfFailure(model); + return undefined; + } + }); + } + + /** + * Notify the user that summarization failed by adding an error message to the chat. + */ + protected notifyUserOfFailure(model: MutableChatModel): void { + const requests = model.getRequests(); + const currentRequest = requests.at(-1); + if (!currentRequest) { + return; + } + + // Avoid duplicate warnings + const lastContent = currentRequest.response.response.content.at(-1); + const alreadyWarned = ErrorChatResponseContent.is(lastContent) && + lastContent.error.message.includes('summarization'); + if (!alreadyWarned) { + const errorMessage = nls.localize( + 'theia/ai-chat/summarizationFailed', + 'Chat summarization failed. The conversation is approaching token limits and may fail soon. Consider starting a new chat session or reducing context to continue.' + ); + currentRequest.response.response.addContent( + new ErrorChatResponseContentImpl(new Error(errorMessage)) + ); + } + } + + hasSummary(sessionId: string): boolean { + const session = this.chatService.getSession(sessionId); + if (!session) { + return false; + } + + const model = session.model as MutableChatModel; + return model.getRequests().some(r => (r as MutableChatRequestModel).isStale); + } + + /** + * Clean up token tracking data when a session is deleted. + */ + protected cleanupSession(sessionId: string): void { + this.tokenTracker.clearSessionBranchTokens(sessionId); + this.pendingSplits.delete(sessionId); + const prefix = `${sessionId}: `; + for (const key of this.triggeredBranches.keys()) { + if (key.startsWith(prefix)) { + this.triggeredBranches.delete(key); + } + } + } +} diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts new file mode 100644 index 0000000000000..02202c6a336cb --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.spec.ts @@ -0,0 +1,253 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import { Container } from '@theia/core/shared/inversify'; +import { ChatSessionTokenTrackerImpl } from './chat-session-token-tracker'; +import { SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +describe('ChatSessionTokenTrackerImpl', () => { + let container: Container; + let tracker: ChatSessionTokenTrackerImpl; + + beforeEach(() => { + container = new Container(); + container.bind(ChatSessionTokenTrackerImpl).toSelf().inSingletonScope(); + tracker = container.get(ChatSessionTokenTrackerImpl); + }); + + describe('getSessionInputTokens', () => { + it('should return undefined for unknown session', () => { + expect(tracker.getSessionInputTokens('unknown-session')).to.be.undefined; + }); + }); + + describe('getSessionOutputTokens', () => { + it('should return undefined for unknown session', () => { + expect(tracker.getSessionOutputTokens('unknown-session')).to.be.undefined; + }); + }); + + describe('getSessionTotalTokens', () => { + it('should return undefined for unknown session', () => { + expect(tracker.getSessionTotalTokens('unknown-session')).to.be.undefined; + }); + + it('should return input tokens when only input is set', () => { + const sessionId = 'session-1'; + tracker.updateSessionTokens(sessionId, 5000); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5000); + }); + + it('should return sum of input and output tokens', () => { + const sessionId = 'session-1'; + tracker.updateSessionTokens(sessionId, 5000, 100); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5100); + }); + }); + + describe('resetSessionTokens', () => { + it('should update token count and fire onSessionTokensUpdated', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Set initial token count via resetSessionTokens + tracker.resetSessionTokens(sessionId, 50000); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; + expect(updateEvents).to.have.length(1); + expect(updateEvents[0].sessionId).to.equal(sessionId); + expect(updateEvents[0].inputTokens).to.equal(50000); + expect(updateEvents[0].outputTokens).to.be.undefined; + + // Reset to new baseline (simulating post-summarization) + const newTokenCount = 10000; + tracker.resetSessionTokens(sessionId, newTokenCount); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(newTokenCount); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; + expect(updateEvents).to.have.length(2); + expect(updateEvents[1].sessionId).to.equal(sessionId); + expect(updateEvents[1].inputTokens).to.equal(newTokenCount); + expect(updateEvents[1].outputTokens).to.be.undefined; + }); + + it('should delete token count and emit undefined when called with undefined', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Set initial token count via resetSessionTokens + tracker.resetSessionTokens(sessionId, 50000); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(50000); + expect(updateEvents).to.have.length(1); + + // Reset to undefined (simulating switch to branch with no prior LLM requests) + tracker.resetSessionTokens(sessionId, undefined); + + expect(tracker.getSessionInputTokens(sessionId)).to.be.undefined; + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; + expect(updateEvents).to.have.length(2); + expect(updateEvents[1].sessionId).to.equal(sessionId); + expect(updateEvents[1].inputTokens).to.be.undefined; + expect(updateEvents[1].outputTokens).to.be.undefined; + }); + + it('should clear output tokens when resetting', () => { + const sessionId = 'session-1'; + + // Set both input and output tokens via updateSessionTokens + tracker.updateSessionTokens(sessionId, 5000, 500); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(500); + + // Reset should clear output tokens + tracker.resetSessionTokens(sessionId, 3000); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(3000); + expect(tracker.getSessionOutputTokens(sessionId)).to.be.undefined; + }); + }); + + describe('updateSessionTokens', () => { + it('should set input tokens and reset output to 0 when input provided', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + tracker.updateSessionTokens(sessionId, 5000); + + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + expect(updateEvents).to.have.length(1); + expect(updateEvents[0].inputTokens).to.equal(5000); + expect(updateEvents[0].outputTokens).to.equal(0); + }); + + it('should update output tokens progressively', () => { + const sessionId = 'session-1'; + const updateEvents: SessionTokenUpdateEvent[] = []; + + tracker.onSessionTokensUpdated(event => updateEvents.push(event)); + + // Initial request with input tokens + tracker.updateSessionTokens(sessionId, 5000, 0); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + + // Progressive updates during streaming + tracker.updateSessionTokens(sessionId, undefined, 100); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(100); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); // Input unchanged + + tracker.updateSessionTokens(sessionId, undefined, 250); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(250); + + expect(updateEvents).to.have.length(3); + expect(updateEvents[2].inputTokens).to.equal(5000); + expect(updateEvents[2].outputTokens).to.equal(250); + }); + + it('should reset output to 0 when new input tokens arrive (new request)', () => { + const sessionId = 'session-1'; + + // First request + tracker.updateSessionTokens(sessionId, 5000, 500); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5500); + + // New request starts - input tokens set, output resets + tracker.updateSessionTokens(sessionId, 5500); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5500); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(0); + expect(tracker.getSessionTotalTokens(sessionId)).to.equal(5500); + }); + + it('should not update input tokens when input is 0', () => { + const sessionId = 'session-1'; + + tracker.updateSessionTokens(sessionId, 5000); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + + // Input of 0 should not reset + tracker.updateSessionTokens(sessionId, 0, 100); + expect(tracker.getSessionInputTokens(sessionId)).to.equal(5000); + expect(tracker.getSessionOutputTokens(sessionId)).to.equal(100); + }); + }); + + describe('branch token methods', () => { + it('should set and get branch tokens', () => { + const sessionId = 'session-1'; + const branchId = 'branch-1'; + + expect(tracker.getBranchTokens(sessionId, branchId)).to.be.undefined; + + tracker.setBranchTokens(sessionId, branchId, 5000); + + expect(tracker.getBranchTokens(sessionId, branchId)).to.equal(5000); + }); + + it('should get all branch tokens for a session', () => { + const sessionId = 'session-1'; + + tracker.setBranchTokens(sessionId, 'branch-1', 1000); + tracker.setBranchTokens(sessionId, 'branch-2', 2000); + tracker.setBranchTokens('other-session', 'branch-3', 3000); + + const result = tracker.getBranchTokensForSession(sessionId); + + expect(result).to.deep.equal({ + 'branch-1': 1000, + 'branch-2': 2000 + }); + }); + + it('should return empty object when no branch tokens exist for session', () => { + const result = tracker.getBranchTokensForSession('unknown-session'); + expect(result).to.deep.equal({}); + }); + + it('should restore branch tokens from persisted data', () => { + const sessionId = 'session-1'; + const branchTokens = { + 'branch-1': 1000, + 'branch-2': 2000 + }; + + tracker.restoreBranchTokens(sessionId, branchTokens); + + expect(tracker.getBranchTokens(sessionId, 'branch-1')).to.equal(1000); + expect(tracker.getBranchTokens(sessionId, 'branch-2')).to.equal(2000); + }); + + it('should clear all branch tokens for a session', () => { + const sessionId = 'session-1'; + + tracker.setBranchTokens(sessionId, 'branch-1', 1000); + tracker.setBranchTokens(sessionId, 'branch-2', 2000); + tracker.setBranchTokens('other-session', 'branch-3', 3000); + + tracker.clearSessionBranchTokens(sessionId); + + expect(tracker.getBranchTokens(sessionId, 'branch-1')).to.be.undefined; + expect(tracker.getBranchTokens(sessionId, 'branch-2')).to.be.undefined; + expect(tracker.getBranchTokens('other-session', 'branch-3')).to.equal(3000); + }); + }); +}); diff --git a/packages/ai-chat/src/browser/chat-session-token-tracker.ts b/packages/ai-chat/src/browser/chat-session-token-tracker.ts new file mode 100644 index 0000000000000..0c39b9ccc6e6c --- /dev/null +++ b/packages/ai-chat/src/browser/chat-session-token-tracker.ts @@ -0,0 +1,157 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { injectable } from '@theia/core/shared/inversify'; +import { Emitter } from '@theia/core'; +import { ChatSessionTokenTracker, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +// Re-export from common for backwards compatibility +export { ChatSessionTokenTracker, SessionTokenUpdateEvent } from '../common/chat-session-token-tracker'; + +/** + * Event fired when a session's token usage crosses the threshold. + */ +export interface SessionTokenThresholdEvent { + sessionId: string; + inputTokens: number; +} + +/** + * The maximum token budget for a chat session. + * This represents the approximate context window size that chat sessions target. + */ +export const CHAT_TOKEN_BUDGET = 200000; + +/** + * The percentage of the token budget at which summarization is triggered. + */ +export const CHAT_TOKEN_THRESHOLD_PERCENT = 0.9; + +/** + * The token threshold at which summarization is triggered. + * When input tokens reach this value (90% of budget), the system will + * attempt to summarize the conversation to stay within context limits. + */ +export const CHAT_TOKEN_THRESHOLD = CHAT_TOKEN_BUDGET * CHAT_TOKEN_THRESHOLD_PERCENT; + +@injectable() +export class ChatSessionTokenTrackerImpl implements ChatSessionTokenTracker { + protected readonly onSessionTokensUpdatedEmitter = new Emitter(); + readonly onSessionTokensUpdated = this.onSessionTokensUpdatedEmitter.event; + + /** + * Map of sessionId -> latest inputTokens count. + * Updated when token usage is reported for requests in that session. + */ + protected sessionInputTokens = new Map(); + + /** + * Map of sessionId -> latest outputTokens count. + * Updated progressively during streaming. + */ + protected sessionOutputTokens = new Map(); + + /** + * Map of branch tokens. Key format: `${sessionId}:${branchId}` + */ + protected branchTokens = new Map(); + + getSessionInputTokens(sessionId: string): number | undefined { + return this.sessionInputTokens.get(sessionId); + } + + getSessionOutputTokens(sessionId: string): number | undefined { + return this.sessionOutputTokens.get(sessionId); + } + + getSessionTotalTokens(sessionId: string): number | undefined { + const input = this.sessionInputTokens.get(sessionId); + const output = this.sessionOutputTokens.get(sessionId); + if (input === undefined && output === undefined) { + return undefined; + } + return (input ?? 0) + (output ?? 0); + } + + /** + * Reset the session's token count to a new baseline. + * Called after summarization to reflect the reduced token usage. + * The new count should reflect only the summary + any non-stale messages. + * + * @param sessionId - The session ID to reset + * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined, outputTokens: undefined }`. + */ + resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void { + if (newTokenCount === undefined) { + this.sessionInputTokens.delete(sessionId); + } else { + this.sessionInputTokens.set(sessionId, newTokenCount); + } + this.sessionOutputTokens.delete(sessionId); + this.onSessionTokensUpdatedEmitter.fire({ sessionId, inputTokens: newTokenCount, outputTokens: undefined }); + } + + updateSessionTokens(sessionId: string, inputTokens?: number, outputTokens?: number): void { + if (inputTokens !== undefined && inputTokens > 0) { + this.sessionInputTokens.set(sessionId, inputTokens); + this.sessionOutputTokens.set(sessionId, 0); + } + if (outputTokens !== undefined) { + this.sessionOutputTokens.set(sessionId, outputTokens); + } + this.onSessionTokensUpdatedEmitter.fire({ + sessionId, + inputTokens: this.sessionInputTokens.get(sessionId), + outputTokens: this.sessionOutputTokens.get(sessionId) + }); + } + + setBranchTokens(sessionId: string, branchId: string, tokens: number): void { + this.branchTokens.set(`${sessionId}:${branchId}`, tokens); + } + + getBranchTokens(sessionId: string, branchId: string): number | undefined { + return this.branchTokens.get(`${sessionId}:${branchId}`); + } + + getBranchTokensForSession(sessionId: string): { [branchId: string]: number } { + const result: { [branchId: string]: number } = {}; + const prefix = `${sessionId}:`; + for (const [key, value] of this.branchTokens.entries()) { + if (key.startsWith(prefix)) { + const branchId = key.substring(prefix.length); + result[branchId] = value; + } + } + return result; + } + + restoreBranchTokens(sessionId: string, branchTokens: { [branchId: string]: number }): void { + for (const [branchId, tokens] of Object.entries(branchTokens)) { + this.branchTokens.set(`${sessionId}:${branchId}`, tokens); + } + } + + clearSessionBranchTokens(sessionId: string): void { + const prefix = `${sessionId}:`; + for (const key of this.branchTokens.keys()) { + if (key.startsWith(prefix)) { + this.branchTokens.delete(key); + } + } + } +} diff --git a/packages/ai-chat/src/browser/index.ts b/packages/ai-chat/src/browser/index.ts new file mode 100644 index 0000000000000..67734e22c894f --- /dev/null +++ b/packages/ai-chat/src/browser/index.ts @@ -0,0 +1,17 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +export * from './chat-session-token-tracker'; diff --git a/packages/ai-chat/src/common/ai-chat-preferences.ts b/packages/ai-chat/src/common/ai-chat-preferences.ts index bd40ccb1aaac6..9a3226b23aafd 100644 --- a/packages/ai-chat/src/common/ai-chat-preferences.ts +++ b/packages/ai-chat/src/common/ai-chat-preferences.ts @@ -20,6 +20,7 @@ import { nls, PreferenceSchema } from '@theia/core'; export const DEFAULT_CHAT_AGENT_PREF = 'ai-features.chat.defaultChatAgent'; export const PIN_CHAT_AGENT_PREF = 'ai-features.chat.pinChatAgent'; export const BYPASS_MODEL_REQUIREMENT_PREF = 'ai-features.chat.bypassModelRequirement'; +export const BUDGET_AWARE_TOOL_LOOP_PREF = 'ai-features.chat.experimentalBudgetAwareToolLoop'; export const PERSISTED_SESSION_LIMIT_PREF = 'ai-features.chat.persistedSessionLimit'; export const aiChatPreferences: PreferenceSchema = { @@ -46,6 +47,15 @@ export const aiChatPreferences: PreferenceSchema = { default: false, title: AI_CORE_PREFERENCES_TITLE, }, + [BUDGET_AWARE_TOOL_LOOP_PREF]: { + type: 'boolean', + description: nls.localize('theia/ai/chat/budgetAwareToolLoop/description', + 'Experimental: Enable budget-aware tool loop. When enabled, the chat agent can trigger summarization mid-turn \ +if the token budget is exceeded during tool call loops. This prevents API errors from context overflow. \ +Requires language model support (currently only Anthropic models).'), + default: false, + title: AI_CORE_PREFERENCES_TITLE, + }, [PERSISTED_SESSION_LIMIT_PREF]: { type: 'number', description: nls.localize('theia/ai/chat/persistedSessionLimit/description', diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 752e2e5bf0191..efc71dfe5d379 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -42,6 +42,7 @@ import { TextMessage, ToolCall, ToolRequest, + UsageResponsePart, } from '@theia/ai-core'; import { Agent, @@ -51,7 +52,7 @@ import { LanguageModelStreamResponsePart } from '@theia/ai-core/lib/common'; import { ContributionProvider, ILogger, isArray, nls } from '@theia/core'; -import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify'; +import { inject, injectable, named, optional, postConstruct } from '@theia/core/shared/inversify'; import { ChatAgentService } from './chat-agent-service'; import { ChatModel, @@ -97,6 +98,39 @@ export namespace SystemMessageDescription { } } +/** + * Symbol for optional injection of the ChatSummarizationService. + * This allows browser implementations to provide summarization support. + */ +export const ChatSessionSummarizationServiceSymbol = Symbol('ChatSessionSummarizationService'); + +/** + * Minimal interface for chat summarization service. + * Used by AbstractChatAgent to optionally trigger summarization. + */ +export interface ChatSummarizationService { + /** + * Update token tracking during streaming. + * Called when usage data is received in the stream, before the response completes. + */ + updateTokens( + sessionId: string, + usage: UsageResponsePart + ): void; + + /** + * Check and handle summarization after response completes. + * Called after the stream ends with the final accumulated usage. + * Usage may be undefined if the stream ended early (e.g., due to budget-exceeded split). + */ + checkAndHandleSummarization( + sessionId: string, + agent: ChatAgent, + request: MutableChatRequestModel, + usage: UsageResponsePart | undefined + ): Promise; +} + export interface ChatSessionContext extends AIVariableContext { request?: ChatRequestModel; model: ChatModel; @@ -170,6 +204,9 @@ export abstract class AbstractChatAgent implements ChatAgent { @inject(DefaultResponseContentFactory) protected defaultContentFactory: DefaultResponseContentFactory; + @inject(ChatSessionSummarizationServiceSymbol) @optional() + protected summarizationService?: ChatSummarizationService; + readonly abstract id: string; readonly abstract name: string; readonly abstract languageModelRequirements: LanguageModelRequirement[]; @@ -231,8 +268,11 @@ export abstract class AbstractChatAgent implements ChatAgent { ]; const languageModelResponse = await this.sendLlmRequest(request, messages, tools, languageModel); - await this.addContentsToResponse(languageModelResponse, request); - await this.onResponseComplete(request); + const usage = await this.addContentsToResponse(languageModelResponse, request); + const summarizationHandled = await this.checkSummarization(request, usage); + if (!summarizationHandled) { + await this.onResponseComplete(request); + } } catch (e) { this.handleError(request, e); @@ -291,7 +331,13 @@ export abstract class AbstractChatAgent implements ChatAgent { model: ChatModel, includeResponseInProgress = false ): Promise { const requestMessages = model.getRequests().flatMap(request => { + // Skip stale requests entirely - their content is replaced by summary nodes + if (request.isStale === true) { + return []; + } + const messages: LanguageModelMessage[] = []; + const text = request.message.parts.map(part => part.promptText).join(''); if (text.length > 0) { messages.push({ @@ -314,7 +360,7 @@ export abstract class AbstractChatAgent implements ChatAgent { })); messages.push(...imageMessages); - if (request.response.isComplete || includeResponseInProgress) { + if (request.response.isComplete || includeResponseInProgress || request.request.kind === 'continuation') { const responseMessages: LanguageModelMessage[] = request.response.response.content .filter(c => { // we do not send errors or informational content @@ -398,6 +444,30 @@ export abstract class AbstractChatAgent implements ChatAgent { return undefined; } + /** + * Hook called after addContentsToResponse() to check if summarization is needed. + * Returns true if summarization was triggered (response handling is complete). + * Returns false if no summarization needed (caller should call onResponseComplete). + * + * Uses the injected ChatSummarizationService if available (browser context). + * Returns false in non-browser contexts where the service is not injected. + * Returns false if no usage data is available (LLM doesn't report usage). + * + * @param request The chat request model + * @param usage Usage data from the response stream for synchronous token tracking + */ + protected async checkSummarization(request: MutableChatRequestModel, usage: UsageResponsePart | undefined): Promise { + if (this.summarizationService) { + return this.summarizationService.checkAndHandleSummarization( + request.session.id, + this, + request, + usage + ); + } + return false; + } + /** * Invoked after the response by the LLM completed successfully. * @@ -408,17 +478,18 @@ export abstract class AbstractChatAgent implements ChatAgent { return request.response.complete(); } - protected abstract addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise; + protected abstract addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise; } @injectable() export abstract class AbstractTextToModelParsingChatAgent extends AbstractChatAgent { - protected async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { const responseAsText = await getTextOfResponse(languageModelResponse); const parsedCommand = await this.parseTextResponse(responseAsText); const content = this.createResponseContent(parsedCommand, request); request.response.response.addContent(content); + return undefined; // Text responses don't have usage data in the same way } protected abstract parseTextResponse(text: string): Promise; @@ -448,15 +519,14 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { @inject(ToolCallChatResponseContentFactory) protected toolCallResponseContentFactory: ToolCallChatResponseContentFactory; - protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: MutableChatRequestModel): Promise { if (isLanguageModelTextResponse(languageModelResponse)) { const contents = this.parseContents(languageModelResponse.text, request); request.response.response.addContents(contents); - return; + return undefined; } if (isLanguageModelStreamResponse(languageModelResponse)) { - await this.addStreamResponse(languageModelResponse, request); - return; + return this.addStreamResponse(languageModelResponse, request); } this.logger.error( 'Received unknown response in agent. Return response as text' @@ -466,17 +536,42 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { JSON.stringify(languageModelResponse) ) ); + return undefined; } - protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: MutableChatRequestModel): Promise { + protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: MutableChatRequestModel): Promise { let completeTextBuffer = ''; let startIndex = request.response.response.content.length; + // Accumulate usage data across multiple usage parts (some providers yield input/output separately) + let accumulatedUsage: UsageResponsePart | undefined; for await (const token of languageModelResponse.stream) { // Skip unknown tokens. For example OpenAI sends empty tokens around tool calls if (!isLanguageModelStreamResponsePart(token)) { console.debug(`Unknown token: '${JSON.stringify(token)}'. Skipping`); continue; } + // Accumulate usage data (some providers yield input/output separately) + if (isUsageResponsePart(token)) { + if (!accumulatedUsage) { + accumulatedUsage = { ...token }; + } else { + // Accumulate non-zero values (providers may yield input and output in separate parts) + if (token.input_tokens > 0) { + accumulatedUsage.input_tokens = token.input_tokens; + } + if (token.output_tokens > 0) { + accumulatedUsage.output_tokens = token.output_tokens; + } + if (token.cache_creation_input_tokens !== undefined) { + accumulatedUsage.cache_creation_input_tokens = token.cache_creation_input_tokens; + } + if (token.cache_read_input_tokens !== undefined) { + accumulatedUsage.cache_read_input_tokens = token.cache_read_input_tokens; + } + } + // Update token tracking in real-time during streaming + this.summarizationService?.updateTokens(request.session.id, accumulatedUsage); + } const newContent = this.parse(token, request); if (!isTextResponsePart(token)) { // For non-text tokens (like tool calls), add them directly @@ -489,8 +584,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { startIndex = request.response.response.content.length; completeTextBuffer = ''; } else { - // parse the entire text so far (since beginning of the stream or last non-text token) - // and replace the entire content with the currently parsed content parts + // Parse accumulated text and replace with parsed content parts completeTextBuffer += token.content; const parsedContents = this.parseContents(completeTextBuffer, request); @@ -503,6 +597,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { request.response.response.addContents(parsedContents); } } + return accumulatedUsage; } protected parse(token: LanguageModelStreamResponsePart, request: MutableChatRequestModel): ChatResponseContent | ChatResponseContent[] { diff --git a/packages/ai-chat/src/common/chat-auto-save.spec.ts b/packages/ai-chat/src/common/chat-auto-save.spec.ts index db87465db5f0c..6e000bda4e99a 100644 --- a/packages/ai-chat/src/common/chat-auto-save.spec.ts +++ b/packages/ai-chat/src/common/chat-auto-save.spec.ts @@ -372,4 +372,5 @@ describe('Chat Auto-Save Mechanism', () => { expect(sessionStore.saveCount).to.be.greaterThan(0); }); }); + }); diff --git a/packages/ai-chat/src/common/chat-content-deserializer.spec.ts b/packages/ai-chat/src/common/chat-content-deserializer.spec.ts index 656847d3f9d19..02ab44832fb3a 100644 --- a/packages/ai-chat/src/common/chat-content-deserializer.spec.ts +++ b/packages/ai-chat/src/common/chat-content-deserializer.spec.ts @@ -30,6 +30,7 @@ import { MarkdownChatResponseContentImpl, ProgressChatResponseContentImpl, QuestionResponseContentImpl, + SummaryChatResponseContentImpl, TextChatResponseContentImpl, ThinkingChatResponseContentImpl, ToolCallChatResponseContentImpl @@ -347,6 +348,36 @@ describe('Chat Content Serialization', () => { }); }); + describe('SummaryChatResponseContentImpl', () => { + it('should serialize and deserialize correctly', async () => { + const original = new SummaryChatResponseContentImpl('This is a summary of the conversation.'); + const serialized = original.toSerializable?.(); + + expect(serialized).to.not.be.undefined; + expect(serialized!.kind).to.equal('summary'); + expect(serialized!.data).to.deep.equal({ content: 'This is a summary of the conversation.' }); + + // Simulate caller populating fallbackMessage + const withFallback = { + ...serialized!, + fallbackMessage: original.asString?.() || original.toString() + }; + + const deserialized = await registry.deserialize(withFallback); + expect(deserialized.kind).to.equal('summary'); + expect(deserialized.asString?.()).to.equal('This is a summary of the conversation.'); + }); + + it('should include summary prefix in language model message', () => { + const original = new SummaryChatResponseContentImpl('Summary content'); + const message = original.toLanguageModelMessage(); + + expect(message.type).to.equal('text'); + expect(message.text).to.include('[Summary of previous conversation]'); + expect(message.text).to.include('Summary content'); + }); + }); + describe('ChatContentDeserializerRegistry', () => { it('should handle unknown content types with fallback', async () => { const unknownContent = { diff --git a/packages/ai-chat/src/common/chat-content-deserializer.ts b/packages/ai-chat/src/common/chat-content-deserializer.ts index 7f862b832f222..62bd9160c0632 100644 --- a/packages/ai-chat/src/common/chat-content-deserializer.ts +++ b/packages/ai-chat/src/common/chat-content-deserializer.ts @@ -25,6 +25,7 @@ import { MarkdownChatResponseContentImpl, ProgressChatResponseContentImpl, QuestionResponseContentImpl, + SummaryChatResponseContentImpl, TextChatResponseContentImpl, ThinkingChatResponseContentImpl, ToolCallChatResponseContentImpl, @@ -39,7 +40,8 @@ import { HorizontalLayoutContentData, ProgressContentData, ErrorContentData, - QuestionContentData + QuestionContentData, + SummaryContentData } from './chat-model'; import { SerializableChatResponseContentData } from './chat-model-serialization'; import { ContributionProvider, ILogger, MaybePromise } from '@theia/core'; @@ -324,5 +326,10 @@ export class DefaultChatContentDeserializerContribution implements ChatContentDe data.selectedOption ) }); + + registry.register({ + kind: 'summary', + deserialize: (data: SummaryContentData) => new SummaryChatResponseContentImpl(data.content) + }); } } diff --git a/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts new file mode 100644 index 0000000000000..9439b1c04a311 --- /dev/null +++ b/packages/ai-chat/src/common/chat-model-hierarchy.spec.ts @@ -0,0 +1,197 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import { ChatAgentLocation } from './chat-agents'; +import { MutableChatModel } from './chat-model'; +import { ParsedChatRequest } from './parsed-chat-request'; + +describe('ChatRequestHierarchyBranchImpl', () => { + + function createParsedRequest(text: string): ParsedChatRequest { + return { + request: { text }, + parts: [{ + kind: 'text', + text, + promptText: text, + range: { start: 0, endExclusive: text.length } + }], + toolRequests: new Map(), + variables: [] + }; + } + + describe('remove()', () => { + it('should not fire onDidChange when removing the last item from a branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + let changeEventFired = false; + model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + changeEventFired = true; + } + }); + + branch!.remove(request); + + expect(changeEventFired).to.be.false; + expect(branch!.items.length).to.equal(0); + }); + + it('should fire onDidChange when removing a non-last item from a branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request1 = model.addRequest(createParsedRequest('First request')); + request1.response.complete(); + + // Add second request as an alternative in the same branch + const branch = model.getBranch(request1.id); + expect(branch).to.not.be.undefined; + const request2 = model.addRequest({ + ...createParsedRequest('Second request'), + request: { + text: 'Second request', + referencedRequestId: request1.id + } + }); + request2.response.complete(); + + let changeEventFired = false; + model.onDidChange(event => { + if (event.kind === 'changeHierarchyBranch') { + changeEventFired = true; + } + }); + + branch!.remove(request1); + + expect(changeEventFired).to.be.true; + expect(branch!.items.length).to.equal(1); + }); + + it('should set activeBranchIndex to -1 when branch becomes empty', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.activeBranchIndex).to.equal(0); + + branch!.remove(request); + + expect(branch!.activeBranchIndex).to.equal(-1); + expect(branch!.items.length).to.equal(0); + }); + + it('should correctly adjust activeBranchIndex when removing item before active', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request1 = model.addRequest(createParsedRequest('First request')); + request1.response.complete(); + + // Add second request as an alternative in the same branch + const branch = model.getBranch(request1.id); + expect(branch).to.not.be.undefined; + + const request2 = model.addRequest({ + ...createParsedRequest('Second request'), + request: { + text: 'Second request', + referencedRequestId: request1.id + } + }); + request2.response.complete(); + + // After adding request2, it should be active (index 1) + expect(branch!.activeBranchIndex).to.equal(1); + + branch!.remove(request1); + + // After removing request1 (index 0), active index should be adjusted to 0 + expect(branch!.activeBranchIndex).to.equal(0); + expect(branch!.items.length).to.equal(1); + expect(branch!.get().id).to.equal(request2.id); + }); + }); + + describe('get()', () => { + it('should throw meaningful error when called on empty branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + // Remove the request to make branch empty + branch!.remove(request); + expect(branch!.items.length).to.equal(0); + + // get() should throw meaningful error instead of crashing + expect(() => branch!.get()).to.throw('Cannot get request from empty branch'); + }); + + it('should return request when branch has items', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Test request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.items.length).to.equal(1); + + // get() should return the request + expect(branch!.get().id).to.equal(request.id); + }); + }); + + describe('dispose()', () => { + it('should not throw when disposing an empty branch', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Single request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + + // Remove the request to make branch empty + branch!.remove(request); + expect(branch!.items.length).to.equal(0); + + // dispose() should not throw on empty branch + expect(() => branch!.dispose()).to.not.throw(); + }); + + it('should dispose all items when branch has items', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const request = model.addRequest(createParsedRequest('Test request')); + request.response.complete(); + + const branch = model.getBranch(request.id); + expect(branch).to.not.be.undefined; + expect(branch!.items.length).to.equal(1); + + // dispose() should not throw and should clear items + expect(() => branch!.dispose()).to.not.throw(); + expect(branch!.items.length).to.equal(0); + }); + }); +}); diff --git a/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts new file mode 100644 index 0000000000000..fbc9ff24856fb --- /dev/null +++ b/packages/ai-chat/src/common/chat-model-insert-summary.spec.ts @@ -0,0 +1,257 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { expect } from 'chai'; +import { ChatAgentLocation } from './chat-agents'; +import { MutableChatModel, SummaryChatResponseContent, SummaryChatResponseContentImpl } from './chat-model'; +import { ParsedChatRequest } from './parsed-chat-request'; + +describe('MutableChatModel.insertSummary()', () => { + + function createParsedRequest(text: string, kind?: 'user' | 'summary'): ParsedChatRequest { + return { + request: { text, kind }, + parts: [{ + kind: 'text', + text, + promptText: text, + range: { start: 0, endExclusive: text.length } + }], + toolRequests: new Map(), + variables: [] + }; + } + + function createModelWithRequests(count: number): MutableChatModel { + const model = new MutableChatModel(ChatAgentLocation.Panel); + for (let i = 1; i <= count; i++) { + const req = model.addRequest(createParsedRequest(`Request ${i}`)); + req.response.complete(); + } + return model; + } + + /** + * Helper to create a summary callback that simulates ChatService.sendRequest(). + * It creates the summary request directly on the model (as sendRequest would do internally) + * and returns the expected result structure. + */ + function createSummaryCallback(model: MutableChatModel, summaryText: string): () => Promise<{ requestId: string; summaryText: string } | undefined> { + return async () => { + // Simulate what ChatService.sendRequest() would do: create a request on the model + const summaryRequest = model.addRequest(createParsedRequest(summaryText, 'summary')); + // Add the summary content to the response + summaryRequest.response.response.addContent(new SummaryChatResponseContentImpl(summaryText)); + summaryRequest.response.complete(); + return { + requestId: summaryRequest.id, + summaryText + }; + }; + } + + describe('basic functionality', () => { + it('should return undefined when model has less than 2 requests', async () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + model.addRequest(createParsedRequest('Single request')); + + const result = await model.insertSummary( + async () => ({ requestId: 'test-id', summaryText: 'Summary text' }), + 'end' + ); + + expect(result).to.be.undefined; + }); + + it('should return undefined when model is empty', async () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + + const result = await model.insertSummary( + async () => ({ requestId: 'test-id', summaryText: 'Summary text' }), + 'end' + ); + + expect(result).to.be.undefined; + }); + + it('should return summary text on success', async () => { + const model = createModelWithRequests(3); + + const result = await model.insertSummary( + createSummaryCallback(model, 'This is a summary'), + 'end' + ); + + expect(result).to.equal('This is a summary'); + }); + }); + + describe('position: end', () => { + it('should append summary at the end', async () => { + const model = createModelWithRequests(3); + + await model.insertSummary( + createSummaryCallback(model, 'Summary text'), + 'end' + ); + + const requests = model.getRequests(); + // Should have 4 requests: 3 original + 1 summary + expect(requests).to.have.lengthOf(4); + expect(requests[3].request.kind).to.equal('summary'); + }); + + it('should mark ALL original requests as stale (including trigger)', async () => { + const model = createModelWithRequests(3); + + await model.insertSummary( + createSummaryCallback(model, 'Summary text'), + 'end' + ); + + const requests = model.getRequests(); + // All 3 original requests (indices 0-2) should be stale (including trigger request) + // This is because for between-turn summarization, the trigger request's content + // has ALREADY been summarized and should be excluded from future prompts + expect(requests[0].isStale).to.be.true; + expect(requests[1].isStale).to.be.true; + expect(requests[2].isStale).to.be.true; + // Summary request (index 3) should NOT be stale (it's created inside callback, after stale list is computed) + expect(requests[3].isStale).to.be.false; + }); + + it('should create SummaryChatResponseContent in response', async () => { + const model = createModelWithRequests(2); + + await model.insertSummary( + createSummaryCallback(model, 'The conversation summary'), + 'end' + ); + + const summaryRequest = model.getRequests().find(r => r.request.kind === 'summary'); + expect(summaryRequest).to.not.be.undefined; + + const content = summaryRequest!.response.response.content; + expect(content).to.have.lengthOf(1); + expect(SummaryChatResponseContent.is(content[0])).to.be.true; + expect((content[0] as SummaryChatResponseContent).content).to.equal('The conversation summary'); + }); + }); + + describe('callback failure handling', () => { + it('should return undefined on callback returning undefined (end position)', async () => { + const model = createModelWithRequests(3); + const originalRequestCount = model.getRequests().length; + + const result = await model.insertSummary( + async () => undefined, + 'end' + ); + + expect(result).to.be.undefined; + // Model should be unchanged - callback didn't create any request + expect(model.getRequests()).to.have.lengthOf(originalRequestCount); + // Stale flags should remain unchanged + model.getRequests().forEach(r => { + expect(r.isStale).to.be.false; + }); + }); + + it('should return undefined on callback throwing error (end position)', async () => { + const model = createModelWithRequests(3); + const originalRequestCount = model.getRequests().length; + + const result = await model.insertSummary( + async () => { throw new Error('Agent failed'); }, + 'end' + ); + + expect(result).to.be.undefined; + // Model should be unchanged - callback didn't create any request before throwing + expect(model.getRequests()).to.have.lengthOf(originalRequestCount); + // Stale flags should remain unchanged + model.getRequests().forEach(r => { + expect(r.isStale).to.be.false; + }); + }); + + }); + + describe('callback creates request via model', () => { + it('should find created request by requestId after callback returns', async () => { + const model = createModelWithRequests(2); + let createdRequestId: string | undefined; + + await model.insertSummary( + async () => { + // Simulate ChatService.sendRequest() creating a request + const createdSummaryRequest = model.addRequest(createParsedRequest('Summary', 'summary')); + createdSummaryRequest.response.response.addContent(new SummaryChatResponseContentImpl('Summary')); + createdSummaryRequest.response.complete(); + createdRequestId = createdSummaryRequest.id; + return { + requestId: createdSummaryRequest.id, + summaryText: 'Summary' + }; + }, + 'end' + ); + + // The summary request should be findable in the model + const summaryRequest = model.getRequests().find(r => r.id === createdRequestId); + expect(summaryRequest).to.not.be.undefined; + expect(summaryRequest!.request.kind).to.equal('summary'); + }); + + it('should return undefined if requestId references non-existent request', async () => { + const model = createModelWithRequests(2); + + const result = await model.insertSummary( + async () => ({ + requestId: 'non-existent-id', + summaryText: 'Summary' + }), + 'end' + ); + + // Should return undefined because request wasn't found + expect(result).to.be.undefined; + }); + }); + + describe('already stale requests', () => { + it('should not re-mark already stale requests', async () => { + const model = createModelWithRequests(4); + // Mark first request as already stale + model.getRequests()[0].isStale = true; + + await model.insertSummary( + createSummaryCallback(model, 'Summary'), + 'end' + ); + + const requests = model.getRequests(); + // First request was already stale, should remain stale + expect(requests[0].isStale).to.be.true; + // Second, third, and fourth requests should now be stale (all original requests) + expect(requests[1].isStale).to.be.true; + expect(requests[2].isStale).to.be.true; + expect(requests[3].isStale).to.be.true; + // Summary request (index 4) should NOT be stale + expect(requests[4].isStale).to.be.false; + }); + }); +}); diff --git a/packages/ai-chat/src/common/chat-model-serialization.spec.ts b/packages/ai-chat/src/common/chat-model-serialization.spec.ts index adc3e568d8b37..bb8e41a9965bc 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.spec.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.spec.ts @@ -35,6 +35,71 @@ describe('ChatModel Serialization and Restoration', () => { }; } + describe('isStale property serialization', () => { + it('should not include isStale when false (default)', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + model.addRequest(createParsedRequest('Hello')); + + const serialized = model.toSerializable(); + + expect(serialized.requests[0].isStale).to.be.undefined; + }); + + it('should include isStale when true', () => { + const model = new MutableChatModel(ChatAgentLocation.Panel); + const req = model.addRequest(createParsedRequest('Hello')); + req.isStale = true; + + const serialized = model.toSerializable(); + + expect(serialized.requests[0].isStale).to.be.true; + }); + + it('should restore isStale flag correctly', () => { + const model1 = new MutableChatModel(ChatAgentLocation.Panel); + const req = model1.addRequest(createParsedRequest('Hello')); + req.isStale = true; + + const serialized = model1.toSerializable(); + const model2 = new MutableChatModel(serialized); + + expect(model2.getRequests()[0].isStale).to.be.true; + }); + + it('should default isStale to false when missing in serialized data', () => { + const serializedData = { + sessionId: 'test-session', + location: ChatAgentLocation.Panel, + hierarchy: { + rootBranchId: 'branch-root', + branches: { + 'branch-root': { + id: 'branch-root', + items: [{ requestId: 'request-1' }], + activeBranchIndex: 0 + } + } + }, + requests: [{ + id: 'request-1', + text: 'Hello' + // isStale is intentionally omitted + }], + responses: [{ + id: 'response-1', + requestId: 'request-1', + isComplete: true, + isError: false, + content: [] + }] + }; + + const model = new MutableChatModel(serializedData); + + expect(model.getRequests()[0].isStale).to.be.false; + }); + }); + describe('Simple tree serialization', () => { it('should serialize a chat with a single request', () => { const model = new MutableChatModel(ChatAgentLocation.Panel); diff --git a/packages/ai-chat/src/common/chat-model-serialization.ts b/packages/ai-chat/src/common/chat-model-serialization.ts index 2a9de4c222fcf..f700eb058d14e 100644 --- a/packages/ai-chat/src/common/chat-model-serialization.ts +++ b/packages/ai-chat/src/common/chat-model-serialization.ts @@ -41,6 +41,10 @@ export interface SerializableChatRequestData { id: string; text: string; agentId?: string; + /** The type of request. Defaults to 'user' if not specified (for backward compatibility). */ + kind?: 'user' | 'summary' | 'continuation'; + /** Indicates this request has been summarized and should be excluded from prompt construction */ + isStale?: boolean; changeSet?: { title: string; elements: SerializableChangeSetElement[]; @@ -126,6 +130,8 @@ export interface SerializedChatData { title?: string; model: SerializedChatModel; saveDate: number; + lastInputTokens?: number; + branchTokens?: { [branchId: string]: number }; } export interface SerializableChatsData { diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 4ffd618d46c96..543c88554e989 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -243,6 +243,8 @@ export interface ChangeSetDecoration { readonly additionalInfoSuffixIcon?: string[]; } +export type ChatRequestKind = 'user' | 'summary' | 'continuation'; + export interface ChatRequest { readonly text: string; readonly displayText?: string; @@ -254,6 +256,10 @@ export interface ChatRequest { readonly referencedRequestId?: string; readonly variables?: readonly AIVariableResolutionRequest[]; readonly modeId?: string; + /** + * The type of request. Defaults to 'user' if not specified. + */ + readonly kind?: ChatRequestKind; } export interface ChatContext { @@ -269,6 +275,8 @@ export interface ChatRequestModel { readonly context: ChatContext; readonly agentId?: string; readonly data?: { [key: string]: unknown }; + /** Indicates this request has been summarized and should be excluded from prompt construction */ + readonly isStale?: boolean; toSerializable(): SerializableChatRequestData; } @@ -425,6 +433,10 @@ export interface ProgressContentData { message: string; } +export interface SummaryContentData { + content: string; +} + export interface ErrorContentData { message: string; stack?: string; @@ -497,6 +509,17 @@ export interface ProgressChatResponseContent message: string; } +/** + * A summary chat response content represents a condensed summary of previous chat messages. + * It is used when chat history exceeds token limits and older messages need to be summarized. + * The summary is displayed collapsed by default and is read-only. + */ +export interface SummaryChatResponseContent extends ChatResponseContent { + kind: 'summary'; + /** The summary text content */ + content: string; +} + export interface Location { uri: URI; position: Position; @@ -644,6 +667,17 @@ export namespace ProgressChatResponseContent { } } +export namespace SummaryChatResponseContent { + export function is(obj: unknown): obj is SummaryChatResponseContent { + return ( + ChatResponseContent.is(obj) && + obj.kind === 'summary' && + 'content' in obj && + typeof (obj as { content: unknown }).content === 'string' + ); + } +} + export type QuestionResponseHandler = ( selectedOption: { text: string, value?: string }, ) => void; @@ -826,6 +860,12 @@ export class MutableChatModel implements ChatModel, Disposable { for (const reqData of data.requests) { const respData = data.responses.find(r => r.requestId === reqData.id); const requestModel = new MutableChatRequestModel(this, reqData, respData); + // Subscribe to request changes and forward to model emitter (same as addRequest) + requestModel.onDidChange(event => { + if (!ChatChangeEvent.isChangeSetEvent(event)) { + this._onDidChangeEmitter.fire(event); + } + }, this, this.toDispose); requestMap.set(requestModel.id, requestModel); } @@ -901,6 +941,67 @@ export class MutableChatModel implements ChatModel, Disposable { return requestModel; } + /** + * Insert a summary into the model. + * Handles stale marking for older requests. + * + * Note: Only 'end' position is currently supported. The 'beforeLast' mode was removed + * because reordering breaks the hierarchy structure - the summary gets added as a + * continuation of the trigger request, and removing the trigger loses the branch connection. + * + * @param summaryCallback - Callback that creates the summary request via ChatService.sendRequest() + * and returns the requestId and summaryText, or undefined on failure. + * @param position - Currently only 'end' is supported (appends summary at end) + * @returns The summary text on success, or undefined on failure + */ + async insertSummary( + summaryCallback: () => Promise<{ requestId: string; summaryText: string } | undefined>, + position: 'end' + ): Promise { + const allRequests = this.getRequests(); + + // Need at least 2 requests to summarize + if (allRequests.length < 2) { + return undefined; + } + + // For 'end' position (between-turn), all existing requests are summarized and should be marked stale + // For other positions, preserve the most recent exchange + const requestToPreserve = position === 'end' ? undefined : allRequests[allRequests.length - 1]; + + // Identify which requests will be marked stale after successful summarization + const requestsToMarkStale = allRequests.filter(r => !r.isStale && r !== requestToPreserve); + + // Call the callback to create the summary request and invoke the agent + // NOTE: Stale marking happens AFTER the callback so the summary agent can see all messages + let result: { requestId: string; summaryText: string } | undefined; + try { + result = await summaryCallback(); + } catch (error) { + result = undefined; + } + + if (!result) { + return undefined; + } + + const { requestId, summaryText } = result; + + // Find the created summary request using findRequest (handles hierarchy search) + const summaryRequest = this._hierarchy.findRequest(requestId); + if (!summaryRequest) { + // Request not found - treat as failure + return undefined; + } + + // Mark older requests as stale + for (const request of requestsToMarkStale) { + request.isStale = true; + } + + return summaryText; + } + protected getTargetForRequestAddition(request: ParsedChatRequest): (addendum: MutableChatRequestModel) => void { const requestId = request.request.referencedRequestId; const branch = requestId !== undefined && this._hierarchy.findBranch(requestId); @@ -1199,7 +1300,7 @@ export class ChatRequestHierarchyImpl[] { - return Array.from(this.iterateBranches()); + return Array.from(this.iterateBranches()).filter(b => b.items.length > 0); } protected *iterateBranches(): Generator> { @@ -1313,6 +1414,9 @@ export class ChatRequestHierarchyBranchImpl i } get(): TRequest { + if (this.items.length === 0 || this._activeIndex < 0) { + throw new Error('Cannot get request from empty branch'); + } return this.items[this.activeBranchIndex].element; } @@ -1329,8 +1433,12 @@ export class ChatRequestHierarchyBranchImpl i const index = this.items.findIndex(version => version.element.id === requestId); if (index !== -1) { this.items.splice(index, 1); - if (this.activeBranchIndex >= index) { - this.activeBranchIndex--; + if (this.items.length === 0) { + // Branch is now empty - set directly without firing event + this._activeIndex = -1; + } else if (this._activeIndex >= index) { + // Branch still has items - use setter to fire change event + this.activeBranchIndex = Math.max(0, this._activeIndex - 1); } } } @@ -1390,7 +1498,8 @@ export class ChatRequestHierarchyBranchImpl i } dispose(): void { - if (Disposable.is(this.get())) { + // Dispose all items if they are disposable (check first item, not get() which throws on empty) + if (this.items.length > 0 && Disposable.is(this.items[0].element)) { this.items.forEach(({ element }) => Disposable.is(element) && element.dispose()); } this.items.length = 0; @@ -1471,6 +1580,7 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe protected _agentId?: string; protected _data: { [key: string]: unknown }; protected _isEditing = false; + protected _isStale = false; readonly message: ParsedChatRequest; protected readonly toDispose = new DisposableCollection(); @@ -1527,9 +1637,14 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe respData?: SerializableChatResponseData ): void { this._id = reqData.id; - this._request = { text: reqData.text }; + this._request = { + text: reqData.text, + kind: reqData.kind // undefined for old sessions = default 'user' behavior + }; this._agentId = reqData.agentId; this._data = {}; + // Restore isStale flag, defaulting to false for backward compatibility + this._isStale = reqData.isStale ?? false; // Create minimal context this._context = { variables: [] }; @@ -1619,6 +1734,14 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe return this._agentId; } + get isStale(): boolean { + return this._isStale; + } + + set isStale(value: boolean) { + this._isStale = value; + } + cancelEdit(): void { if (this.isEditing) { this._isEditing = false; @@ -1648,15 +1771,22 @@ export class MutableChatRequestModel implements ChatRequestModel, EditableChatRe } toSerializable(): SerializableChatRequestData { - return { + const result: SerializableChatRequestData = { id: this.id, text: this.request.text, agentId: this.agentId, + // Only include kind when not default to minimize payload + kind: this.request.kind !== 'user' ? this.request.kind : undefined, changeSet: this._changeSet ? { title: this._changeSet.title, elements: this._changeSet.getElements().map(elem => elem.toSerializable?.()).filter((elem): elem is SerializableChangeSetElement => elem !== undefined) } : undefined }; + // Only include isStale when true to minimize payload + if (this._isStale) { + result.isStale = true; + } + return result; } dispose(): void { @@ -2605,6 +2735,47 @@ export class ProgressChatResponseContentImpl implements ProgressChatResponseCont } } +/** + * Implementation of SummaryChatResponseContent. + * Represents a summary of previous chat messages that exceeded token limits. + * The summary content is included in prompts to maintain conversation context. + */ +export class SummaryChatResponseContentImpl implements SummaryChatResponseContent { + readonly kind = 'summary'; + protected _content: string; + + constructor(content: string) { + this._content = content; + } + + get content(): string { + return this._content; + } + + asString(): string { + return this._content; + } + + asDisplayString(): string | undefined { + return this._content; + } + + toLanguageModelMessage(): TextMessage { + return { + actor: 'ai', + type: 'text', + text: `[Summary of previous conversation]\n${this._content}` + }; + } + + toSerializable(): SerializableChatResponseContentData { + return { + kind: 'summary', + data: { content: this._content } + }; + } +} + /** * Fallback content for unknown content types. * Used when a deserializer is not available (e.g., content from removed extension). diff --git a/packages/ai-chat/src/common/chat-request-parser.spec.ts b/packages/ai-chat/src/common/chat-request-parser.spec.ts index bf1f76e7e5ae8..d2f9c3d5225f7 100644 --- a/packages/ai-chat/src/common/chat-request-parser.spec.ts +++ b/packages/ai-chat/src/common/chat-request-parser.spec.ts @@ -20,19 +20,24 @@ import { ChatRequestParserImpl } from './chat-request-parser'; import { ChatAgentLocation } from './chat-agents'; import { ChatContext, ChatRequest } from './chat-model'; import { expect } from 'chai'; -import { AIVariable, DefaultAIVariableService, ResolvedAIVariable, ToolInvocationRegistryImpl, ToolRequest } from '@theia/ai-core'; +import { AIVariable, AIVariableService, ResolvedAIVariable, ToolInvocationRegistry, ToolRequest } from '@theia/ai-core'; import { ILogger, Logger } from '@theia/core'; import { ParsedChatRequestTextPart, ParsedChatRequestVariablePart } from './parsed-chat-request'; describe('ChatRequestParserImpl', () => { const chatAgentService = sinon.createStubInstance(ChatAgentServiceImpl); - const variableService = sinon.createStubInstance(DefaultAIVariableService); - const toolInvocationRegistry = sinon.createStubInstance(ToolInvocationRegistryImpl); + const variableService = { + getVariable: sinon.stub(), + resolveVariable: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; + const toolInvocationRegistry = { + getFunction: sinon.stub() + } as unknown as sinon.SinonStubbedInstance; const logger: ILogger = sinon.createStubInstance(Logger); const parser = new ChatRequestParserImpl(chatAgentService, variableService, toolInvocationRegistry, logger); beforeEach(() => { - // Reset our stubs before each test + // Reset all stubs before each test to ensure test isolation sinon.reset(); }); diff --git a/packages/ai-chat/src/common/chat-service.ts b/packages/ai-chat/src/common/chat-service.ts index 143d17de20424..8654b1836fbbf 100644 --- a/packages/ai-chat/src/common/chat-service.ts +++ b/packages/ai-chat/src/common/chat-service.ts @@ -41,7 +41,7 @@ import { import { ChatRequestParser } from './chat-request-parser'; import { ChatSessionNamingService } from './chat-session-naming-service'; import { ParsedChatRequest, ParsedChatRequestAgentPart } from './parsed-chat-request'; -import { ChatSessionIndex, ChatSessionStore } from './chat-session-store'; +import { ChatModelWithMetadata, ChatSessionIndex, ChatSessionStore } from './chat-session-store'; import { ChatContentDeserializerRegistry } from './chat-content-deserializer'; import { ChangeSetDeserializationContext, ChangeSetElementDeserializerRegistry } from './change-set-element-deserializer'; import { SerializableChangeSetElement, SerializedChatModel } from './chat-model-serialization'; @@ -85,6 +85,8 @@ export function isActiveSessionChangedEvent(obj: unknown): obj is ActiveSessionC export interface SessionCreatedEvent { type: 'created'; sessionId: string; + tokenCount?: number; + branchTokens?: { [branchId: string]: number }; } export function isSessionCreatedEvent(obj: unknown): obj is SessionCreatedEvent { @@ -274,7 +276,11 @@ export class ChatServiceImpl implements ChatService { return undefined; } - this.cancelIncompleteRequests(session); + // Don't cancel incomplete requests for summary requests - they run concurrently + // with the active request during mid-turn summarization + if (request.kind !== 'summary') { + this.cancelIncompleteRequests(session); + } const resolutionContext: ChatSessionContext = { model: session.model }; const resolvedContext = await this.resolveChatContext(request.variables ?? session.model.context.getVariables(), resolutionContext); @@ -315,7 +321,11 @@ export class ChatServiceImpl implements ChatService { } }); - agent.invoke(requestModel).catch(error => requestModel.response.error(error)); + // Don't invoke agent for continuation requests - they are containers for moved tool results + // The tool loop handles completing the response externally + if (request.kind !== 'continuation') { + agent.invoke(requestModel).catch(error => requestModel.response.error(error)); + } return invocation; } @@ -329,6 +339,9 @@ export class ChatServiceImpl implements ChatService { } protected updateSessionMetadata(session: ChatSessionInternal, request: MutableChatRequestModel): void { + if (request.request.kind === 'summary') { + return; + } session.lastInteraction = new Date(); if (session.title) { return; @@ -450,9 +463,12 @@ export class ChatServiceImpl implements ChatService { } // Store session with title and pinned agent info - this.sessionStore.storeSessions( - { model: session.model, title: session.title, pinnedAgentId: session.pinnedAgent?.id } - ).catch(error => { + const sessionData: ChatModelWithMetadata = { + model: session.model, + title: session.title, + pinnedAgentId: session.pinnedAgent?.id + }; + this.sessionStore.storeSessions(sessionData).catch(error => { this.logger.error('Failed to store chat sessions', error); }); } @@ -513,7 +529,12 @@ export class ChatServiceImpl implements ChatService { }; this._sessions.push(session); this.setupAutoSaveForSession(session); - this.onSessionEventEmitter.fire({ type: 'created', sessionId: session.id }); + this.onSessionEventEmitter.fire({ + type: 'created', + sessionId: session.id, + tokenCount: serialized.lastInputTokens, + branchTokens: serialized.branchTokens + }); this.logger.debug('Session successfully restored and registered', { sessionId, title: session.title }); diff --git a/packages/ai-chat/src/common/chat-session-store.ts b/packages/ai-chat/src/common/chat-session-store.ts index 719bd6502001b..ea84e5287f870 100644 --- a/packages/ai-chat/src/common/chat-session-store.ts +++ b/packages/ai-chat/src/common/chat-session-store.ts @@ -24,6 +24,8 @@ export interface ChatModelWithMetadata { model: ChatModel; title?: string; pinnedAgentId?: string; + lastInputTokens?: number; + branchTokens?: { [branchId: string]: number }; } export interface ChatSessionStore { diff --git a/packages/ai-chat/src/common/chat-session-summary-agent.ts b/packages/ai-chat/src/common/chat-session-summary-agent.ts index 298a27aea0ccc..67d001ab75544 100644 --- a/packages/ai-chat/src/common/chat-session-summary-agent.ts +++ b/packages/ai-chat/src/common/chat-session-summary-agent.ts @@ -31,6 +31,7 @@ export class ChatSessionSummaryAgent extends AbstractStreamParsingChatAgent impl override description = nls.localize('theia/ai/chat/chatSessionSummaryAgent/description', 'Agent for generating chat session summaries.'); override prompts: PromptVariantSet[] = [CHAT_SESSION_SUMMARY_PROMPT]; protected readonly defaultLanguageModelPurpose = 'chat-session-summary'; + protected override systemPromptId: string = CHAT_SESSION_SUMMARY_PROMPT.id; languageModelRequirements: LanguageModelRequirement[] = [{ purpose: 'chat-session-summary', identifier: 'default/summarize', diff --git a/packages/ai-chat/src/common/chat-session-token-tracker.ts b/packages/ai-chat/src/common/chat-session-token-tracker.ts new file mode 100644 index 0000000000000..c67558d46da0f --- /dev/null +++ b/packages/ai-chat/src/common/chat-session-token-tracker.ts @@ -0,0 +1,128 @@ +// ***************************************************************************** +// Copyright (C) 2025 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { Event } from '@theia/core'; + +/** + * Event fired when a session's token count is updated. + */ +export interface SessionTokenUpdateEvent { + sessionId: string; + /** + * The input token count for the active branch. + * - `number`: Known token count from the most recent LLM response + * - `undefined`: Unknown/not yet measured (branch has never had an LLM request; do NOT coerce to 0) + */ + inputTokens: number | undefined; + /** + * The output token count for the active branch. + * - `number`: Known token count from the most recent LLM response (updated progressively during streaming) + * - `undefined`: Unknown/not yet measured + */ + outputTokens: number | undefined; +} + +export const ChatSessionTokenTracker = Symbol('ChatSessionTokenTracker'); + +/** + * Service that tracks token usage per chat session. + * + * Listens to token usage updates from the backend and correlates them with + * chat sessions via requestId. When a session's input tokens exceed the + * threshold (90% of 200k), it emits an event for summarization. + */ +export interface ChatSessionTokenTracker { + /** + * Event fired when a session's token count is updated. + */ + readonly onSessionTokensUpdated: Event; + + /** + * Get the latest input token count for a session. + * Returns the inputTokens from the most recent request in the session. + */ + getSessionInputTokens(sessionId: string): number | undefined; + + /** + * Get the latest output token count for a session. + * Returns the outputTokens from the most recent request in the session. + */ + getSessionOutputTokens(sessionId: string): number | undefined; + + /** + * Get the total token count for a session (input + output). + * Returns the sum of input and output tokens, representing current context window usage. + * Returns undefined if neither input nor output tokens are known. + */ + getSessionTotalTokens(sessionId: string): number | undefined; + + /** + * Reset the session's token count to a new baseline. + * Called after summarization to reflect the reduced token usage. + * + * @param sessionId - The session ID to reset + * @param newTokenCount - The new token count, or `undefined` to indicate unknown state. + * When `undefined`, deletes the stored count and emits `{ inputTokens: undefined, outputTokens: undefined }`. + */ + resetSessionTokens(sessionId: string, newTokenCount: number | undefined): void; + + /** + * Update session tokens with input and/or output token counts. + * Called when token usage information is received from the LLM. + * + * @param sessionId - The session ID to update + * @param inputTokens - The input token count (sets new baseline, resets output to 0) + * @param outputTokens - The output token count (updated progressively during streaming) + */ + updateSessionTokens(sessionId: string, inputTokens?: number, outputTokens?: number): void; + + /** + * Store token count for a specific branch. + * @param sessionId - The session ID + * @param branchId - The branch ID + * @param tokens - The token count + */ + setBranchTokens(sessionId: string, branchId: string, tokens: number): void; + + /** + * Get token count for a specific branch. + * @param sessionId - The session ID + * @param branchId - The branch ID + * @returns The token count, or undefined if not tracked + */ + getBranchTokens(sessionId: string, branchId: string): number | undefined; + + /** + * Get all branch token counts for a session. + * @param sessionId - The session ID + * @returns Object with branchId keys and token count values, or empty object if no data + */ + getBranchTokensForSession(sessionId: string): { [branchId: string]: number }; + + /** + * Restore branch tokens from persisted data. + * @param sessionId - The session ID + * @param branchTokens - Object with branchId keys and token count values + */ + restoreBranchTokens(sessionId: string, branchTokens: { [branchId: string]: number }): void; + + /** + * Clear all branch token data for a session. + * Called when a session is deleted. + * @param sessionId - The session ID + */ + clearSessionBranchTokens(sessionId: string): void; +} diff --git a/packages/ai-chat/src/common/index.ts b/packages/ai-chat/src/common/index.ts index be96415422ba6..7486de135c9a5 100644 --- a/packages/ai-chat/src/common/index.ts +++ b/packages/ai-chat/src/common/index.ts @@ -23,6 +23,7 @@ export * from './chat-model-util'; export * from './chat-request-parser'; export * from './chat-service'; export * from './chat-session-store'; +export * from './chat-session-token-tracker'; export * from './custom-chat-agent'; export * from './parsed-chat-request'; export * from './context-variables'; diff --git a/packages/ai-core/src/browser/frontend-language-model-service.ts b/packages/ai-core/src/browser/frontend-language-model-service.ts index 087ed8aafea9a..7b6bba2a13a77 100644 --- a/packages/ai-core/src/browser/frontend-language-model-service.ts +++ b/packages/ai-core/src/browser/frontend-language-model-service.ts @@ -31,30 +31,12 @@ export class FrontendLanguageModelServiceImpl extends LanguageModelServiceImpl { languageModel: LanguageModel, languageModelRequest: UserRequest ): Promise { - const requestSettings = this.preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); - - const ids = languageModel.id.split('/'); - const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], languageModelRequest.agentId); - if (matchingSetting?.requestSettings) { - // Merge the settings, with user request taking precedence - languageModelRequest.settings = { - ...matchingSetting.requestSettings, - ...languageModelRequest.settings - }; - } - if (matchingSetting?.clientSettings) { - // Merge the clientSettings, with user request taking precedence - languageModelRequest.clientSettings = { - ...matchingSetting.clientSettings, - ...languageModelRequest.clientSettings - }; - } - + applyRequestSettings(languageModelRequest, languageModel.id, languageModelRequest.agentId, this.preferenceService); return super.sendRequest(languageModel, languageModelRequest); } } -export const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: string, providerId: string, agentId?: string): RequestSetting => { +const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: string, providerId: string, agentId?: string): RequestSetting => { const prioritizedSettings = Prioritizeable.prioritizeAllSync(requestSettings, setting => getRequestSettingSpecificity(setting, { modelId, @@ -65,3 +47,31 @@ export const mergeRequestSettings = (requestSettings: RequestSetting[], modelId: const matchingSetting = prioritizedSettings.reduceRight((acc, cur) => ({ ...acc, ...cur.value }), {} as RequestSetting); return matchingSetting; }; + +/** + * Apply request settings from preferences to a user request. + * Merges settings based on model ID, provider ID, and agent ID specificity. + */ +export const applyRequestSettings = ( + request: UserRequest, + languageModelId: string, + agentId: string | undefined, + preferenceService: PreferenceService +): void => { + const requestSettings = preferenceService.get(PREFERENCE_NAME_REQUEST_SETTINGS, []); + const ids = languageModelId.split('/'); + const matchingSetting = mergeRequestSettings(requestSettings, ids[1], ids[0], agentId); + + if (matchingSetting?.requestSettings) { + request.settings = { + ...matchingSetting.requestSettings, + ...request.settings + }; + } + if (matchingSetting?.clientSettings) { + request.clientSettings = { + ...matchingSetting.clientSettings, + ...request.clientSettings + }; + } +}; diff --git a/packages/ai-core/src/common/language-model.ts b/packages/ai-core/src/common/language-model.ts index 1f9b87b3a52ca..e8aa9453dd829 100644 --- a/packages/ai-core/src/common/language-model.ts +++ b/packages/ai-core/src/common/language-model.ts @@ -167,7 +167,20 @@ export interface LanguageModelRequest { tools?: ToolRequest[]; response_format?: { type: 'text' } | { type: 'json_object' } | ResponseFormatJsonSchema; settings?: { [key: string]: unknown }; - clientSettings?: { keepToolCalls: boolean; keepThinking: boolean } + clientSettings?: { keepToolCalls: boolean; keepThinking: boolean }; + /** + * If true, the model should return after the first LLM response + * without executing tool calls. The caller handles tool execution + * and continuation. + * + * Models that don't support this property ignore it and handle + * the tool loop internally (current behavior). This allows gradual + * migration - once all models support it, the old tool loop code + * can be removed. + * + * Default: false (model handles tool loop internally). + */ + singleRoundTrip?: boolean; } export interface ResponseFormatJsonSchema { type: 'json_schema'; @@ -222,6 +235,10 @@ export const isLanguageModelStreamResponsePart = (part: unknown): part is Langua export interface UsageResponsePart { input_tokens: number; output_tokens: number; + /** Input tokens written to cache (Anthropic-specific) */ + cache_creation_input_tokens?: number; + /** Input tokens read from cache (Anthropic-specific) */ + cache_read_input_tokens?: number; } export const isUsageResponsePart = (part: unknown): part is UsageResponsePart => !!(part && typeof part === 'object' && diff --git a/packages/ai-core/src/common/token-usage-service.ts b/packages/ai-core/src/common/token-usage-service.ts index 5c8b7211c63b0..c67252912ec68 100644 --- a/packages/ai-core/src/common/token-usage-service.ts +++ b/packages/ai-core/src/common/token-usage-service.ts @@ -33,6 +33,8 @@ export interface TokenUsage { timestamp: Date; /** Request identifier */ requestId: string; + /** Session identifier */ + sessionId?: string; } export interface TokenUsageParams { @@ -46,6 +48,8 @@ export interface TokenUsageParams { readCachedInputTokens?: number; /** Request identifier */ requestId: string; + /** Session identifier */ + sessionId?: string; } export interface TokenUsageService { diff --git a/packages/ai-core/src/node/token-usage-service-impl.ts b/packages/ai-core/src/node/token-usage-service-impl.ts index 8dad2ac1ae156..c68a17f07549a 100644 --- a/packages/ai-core/src/node/token-usage-service-impl.ts +++ b/packages/ai-core/src/node/token-usage-service-impl.ts @@ -46,7 +46,8 @@ export class TokenUsageServiceImpl implements TokenUsageService { outputTokens: params.outputTokens, model, timestamp: new Date(), - requestId: params.requestId + requestId: params.requestId, + sessionId: params.sessionId }; this.tokenUsages.push(usage); diff --git a/packages/ai-google/src/node/google-language-model.ts b/packages/ai-google/src/node/google-language-model.ts index e8c986d871e35..4b5a35c6dc0ca 100644 --- a/packages/ai-google/src/node/google-language-model.ts +++ b/packages/ai-google/src/node/google-language-model.ts @@ -309,16 +309,15 @@ export class GoogleModel implements LanguageModel { yield { content: chunk.text }; } - // Report token usage if available - if (chunk.usageMetadata && that.tokenUsageService && that.id) { + // Yield usage data when available + if (chunk.usageMetadata) { const promptTokens = chunk.usageMetadata.promptTokenCount; const completionTokens = chunk.usageMetadata.candidatesTokenCount; - if (promptTokens && completionTokens) { - that.tokenUsageService.recordTokenUsage(that.id, { - inputTokens: promptTokens, - outputTokens: completionTokens, - requestId: request.requestId - }).catch(error => console.error('Error recording token usage:', error)); + if (promptTokens !== undefined && completionTokens !== undefined) { + yield { + input_tokens: promptTokens, + output_tokens: completionTokens + }; } } } @@ -444,18 +443,6 @@ export class GoogleModel implements LanguageModel { responseText = model.text ?? ''; } - // Record token usage if available - if (model.usageMetadata && this.tokenUsageService) { - const promptTokens = model.usageMetadata.promptTokenCount; - const completionTokens = model.usageMetadata.candidatesTokenCount; - if (promptTokens && completionTokens) { - await this.tokenUsageService.recordTokenUsage(this.id, { - inputTokens: promptTokens, - outputTokens: completionTokens, - requestId: request.requestId - }); - } - } return { text: responseText }; } catch (error) { throw new Error(`Failed to get response from Gemini API: ${error instanceof Error ? error.message : 'Unknown error'}`); diff --git a/packages/ai-history/src/browser/ai-history-exchange-card.tsx b/packages/ai-history/src/browser/ai-history-exchange-card.tsx index f97a02e9f3cd1..832ed52338e61 100644 --- a/packages/ai-history/src/browser/ai-history-exchange-card.tsx +++ b/packages/ai-history/src/browser/ai-history-exchange-card.tsx @@ -29,6 +29,22 @@ const getTextFromResponse = (response: LanguageModelExchangeRequestResponse): st for (const chunk of response.parts) { if ('content' in chunk && chunk.content) { result += chunk.content; + } else if ('tool_calls' in chunk && Array.isArray(chunk.tool_calls)) { + // Format tool calls for display - only show finished tool calls to avoid duplicates + for (const toolCall of chunk.tool_calls) { + if (toolCall.finished && toolCall.function?.name) { + result += `[Tool Call: ${toolCall.function.name}`; + if (toolCall.function.arguments) { + // Truncate long arguments for readability + const args = toolCall.function.arguments; + const truncatedArgs = args.length > 100 ? args.substring(0, 100) + '...' : args; + result += `(${truncatedArgs})`; + } + result += ']\n'; + } + } + } else if ('thought' in chunk && chunk.thought) { + result += `[Thinking: ${chunk.thought.substring(0, 100)}...]\n`; } } return result; diff --git a/packages/ai-ide/src/common/orchestrator-chat-agent.ts b/packages/ai-ide/src/common/orchestrator-chat-agent.ts index 93731af3c165e..852f389052ba4 100644 --- a/packages/ai-ide/src/common/orchestrator-chat-agent.ts +++ b/packages/ai-ide/src/common/orchestrator-chat-agent.ts @@ -14,7 +14,16 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { AIVariableContext, getJsonOfText, getTextOfResponse, LanguageModel, LanguageModelMessage, LanguageModelRequirement, LanguageModelResponse } from '@theia/ai-core'; +import { + AIVariableContext, + getJsonOfText, + getTextOfResponse, + LanguageModel, + LanguageModelMessage, + LanguageModelRequirement, + LanguageModelResponse, + UsageResponsePart +} from '@theia/ai-core'; import { inject, injectable } from '@theia/core/shared/inversify'; import { ChatAgentService } from '@theia/ai-chat/lib/common/chat-agent-service'; import { ChatToolRequest } from '@theia/ai-chat/lib/common/chat-tool-request-service'; @@ -126,7 +135,7 @@ export class OrchestratorChatAgent extends AbstractStreamParsingChatAgent { ); } - protected override async addContentsToResponse(response: LanguageModelResponse, request: MutableChatRequestModel): Promise { + protected override async addContentsToResponse(response: LanguageModelResponse, request: MutableChatRequestModel): Promise { const responseText = await getTextOfResponse(response); let agentIds: string[] = []; @@ -184,6 +193,9 @@ export class OrchestratorChatAgent extends AbstractStreamParsingChatAgent { // Get the original request if available const originalRequest = '__originalRequest' in request ? request.__originalRequest as MutableChatRequestModel : request; await agent.invoke(originalRequest); + + // Orchestrator delegates to another agent, no usage data to return + return undefined; } } diff --git a/packages/ai-ollama/src/node/ollama-language-model.ts b/packages/ai-ollama/src/node/ollama-language-model.ts index ab17b810ca97f..1a48cda94ee91 100644 --- a/packages/ai-ollama/src/node/ollama-language-model.ts +++ b/packages/ai-ollama/src/node/ollama-language-model.ts @@ -17,7 +17,6 @@ import { LanguageModel, LanguageModelParsedResponse, - LanguageModelRequest, LanguageModelMessage, LanguageModelResponse, LanguageModelStreamResponse, @@ -27,10 +26,11 @@ import { ToolRequestParametersProperties, ImageContent, TokenUsageService, - LanguageModelStatus + LanguageModelStatus, + UserRequest } from '@theia/ai-core'; import { CancellationToken } from '@theia/core'; -import { ChatRequest, Message, Ollama, Options, Tool, ToolCall as OllamaToolCall, ChatResponse } from 'ollama'; +import { ChatRequest, Message, Ollama, Options, Tool, ToolCall as OllamaToolCall } from 'ollama'; export const OllamaModelIdentifier = Symbol('OllamaModelIdentifier'); @@ -58,7 +58,7 @@ export class OllamaModel implements LanguageModel { protected readonly tokenUsageService?: TokenUsageService ) { } - async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise { + async request(request: UserRequest, cancellationToken?: CancellationToken): Promise { const settings = this.getSettings(request); const ollama = this.initializeOllama(); const stream = !(request.settings?.stream === false); // true by default, false only if explicitly specified @@ -71,7 +71,8 @@ export class OllamaModel implements LanguageModel { stream }; const structured = request.response_format?.type === 'json_schema'; - return this.dispatchRequest(ollama, ollamaRequest, structured, cancellationToken); + const sessionId = request.sessionId; + return this.dispatchRequest(ollama, ollamaRequest, structured, sessionId, cancellationToken); } /** @@ -79,14 +80,20 @@ export class OllamaModel implements LanguageModel { * @param request The language model request containing specific settings. * @returns A partial ChatRequest object containing the merged settings. */ - protected getSettings(request: LanguageModelRequest): Partial { + protected getSettings(request: UserRequest): Partial { const settings = request.settings ?? {}; return { options: settings as Partial }; } - protected async dispatchRequest(ollama: Ollama, ollamaRequest: ExtendedChatRequest, structured: boolean, cancellation?: CancellationToken): Promise { + protected async dispatchRequest( + ollama: Ollama, + ollamaRequest: ExtendedChatRequest, + structured: boolean, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { // Handle structured output request if (structured) { @@ -95,14 +102,19 @@ export class OllamaModel implements LanguageModel { if (isNonStreaming(ollamaRequest)) { // handle non-streaming request - return this.handleNonStreamingRequest(ollama, ollamaRequest, cancellation); + return this.handleNonStreamingRequest(ollama, ollamaRequest, sessionId, cancellation); } // handle streaming request - return this.handleStreamingRequest(ollama, ollamaRequest, cancellation); + return this.handleStreamingRequest(ollama, ollamaRequest, sessionId, cancellation); } - protected async handleStreamingRequest(ollama: Ollama, chatRequest: ExtendedChatRequest, cancellation?: CancellationToken): Promise { + protected async handleStreamingRequest( + ollama: Ollama, + chatRequest: ExtendedChatRequest, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { const responseStream = await ollama.chat({ ...chatRequest, stream: true, @@ -146,7 +158,13 @@ export class OllamaModel implements LanguageModel { } if (chunk.done) { - that.recordTokenUsage(chunk); + // Yield usage data when stream completes + if (chunk.prompt_eval_count !== undefined && chunk.eval_count !== undefined) { + yield { + input_tokens: chunk.prompt_eval_count, + output_tokens: chunk.eval_count + }; + } if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -169,6 +187,7 @@ export class OllamaModel implements LanguageModel { const continuedResponse = await that.handleStreamingRequest( ollama, chatRequest, + sessionId, cancellation ); @@ -222,7 +241,12 @@ export class OllamaModel implements LanguageModel { } } - protected async handleNonStreamingRequest(ollama: Ollama, chatRequest: ExtendedNonStreamingChatRequest, cancellation?: CancellationToken): Promise { + protected async handleNonStreamingRequest( + ollama: Ollama, + chatRequest: ExtendedNonStreamingChatRequest, + sessionId?: string, + cancellation?: CancellationToken + ): Promise { try { // even though we have a non-streaming request, we still use the streaming version for two reasons: // 1. we can abort the stream if the request is cancelled instead of having to wait for the entire response @@ -249,9 +273,8 @@ export class OllamaModel implements LanguageModel { toolCalls.push(...chunk.message.tool_calls); } - // if the response is done, record the token usage and check the done reason + // if the response is done, check the done reason if (chunk.done) { - this.recordTokenUsage(chunk); lastUpdated = chunk.created_at; if (chunk.done_reason && chunk.done_reason !== 'stop') { throw new Error('Ollama stopped unexpectedly. Reason: ' + chunk.done_reason); @@ -273,7 +296,7 @@ export class OllamaModel implements LanguageModel { } // recurse to get the final response content (the intermediate content remains hidden, it is only part of the conversation) - return this.handleNonStreamingRequest(ollama, chatRequest); + return this.handleNonStreamingRequest(ollama, chatRequest, sessionId); } // if no tool calls are necessary, return the final response content @@ -315,16 +338,6 @@ export class OllamaModel implements LanguageModel { return toolCallsForResponse; } - private recordTokenUsage(response: ChatResponse): void { - if (this.tokenUsageService && response.prompt_eval_count && response.eval_count) { - this.tokenUsageService.recordTokenUsage(this.id, { - inputTokens: response.prompt_eval_count, - outputTokens: response.eval_count, - requestId: `ollama_${response.created_at}` - }).catch(error => console.error('Error recording token usage:', error)); - } - } - protected initializeOllama(): Ollama { const host = this.host(); if (!host) { diff --git a/packages/ai-openai/src/node/openai-language-model.ts b/packages/ai-openai/src/node/openai-language-model.ts index 408c599446332..802b0d5fcbd69 100644 --- a/packages/ai-openai/src/node/openai-language-model.ts +++ b/packages/ai-openai/src/node/openai-language-model.ts @@ -161,7 +161,7 @@ export class OpenAiModel implements LanguageModel { }); } - return { stream: new StreamingAsyncIterator(runner, request.requestId, cancellationToken, this.tokenUsageService, this.id) }; + return { stream: new StreamingAsyncIterator(runner, request.requestId, request.sessionId, cancellationToken, this.tokenUsageService, this.id) }; } protected async handleNonStreamingRequest(openai: OpenAI, request: UserRequest): Promise { @@ -174,18 +174,6 @@ export class OpenAiModel implements LanguageModel { const message = response.choices[0].message; - // Record token usage if token usage service is available - if (this.tokenUsageService && response.usage) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: response.usage.prompt_tokens, - outputTokens: response.usage.completion_tokens, - requestId: request.requestId - } - ); - } - return { text: message.content ?? '' }; @@ -209,18 +197,6 @@ export class OpenAiModel implements LanguageModel { console.error('Error in OpenAI chat completion stream:', JSON.stringify(message)); } - // Record token usage if token usage service is available - if (this.tokenUsageService && result.usage) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: result.usage.prompt_tokens, - outputTokens: result.usage.completion_tokens, - requestId: request.requestId - } - ); - } - return { content: message.content ?? '', parsed: message.parsed diff --git a/packages/ai-openai/src/node/openai-response-api-utils.ts b/packages/ai-openai/src/node/openai-response-api-utils.ts index 6f22f2a5a302b..6a92fc67b3069 100644 --- a/packages/ai-openai/src/node/openai-response-api-utils.ts +++ b/packages/ai-openai/src/node/openai-response-api-utils.ts @@ -95,7 +95,7 @@ export class OpenAiResponseApiUtils { input, ...settings }); - return { stream: this.createSimpleResponseApiStreamIterator(stream, request.requestId, modelId, tokenUsageService, cancellationToken) }; + return { stream: this.createSimpleResponseApiStreamIterator(stream, request.requestId, request.sessionId, modelId, tokenUsageService, cancellationToken) }; } else { const response = await openai.responses.create({ model: model as ResponsesModel, @@ -104,18 +104,6 @@ export class OpenAiResponseApiUtils { ...settings }); - // Record token usage if available - if (tokenUsageService && response.usage) { - await tokenUsageService.recordTokenUsage( - modelId, - { - inputTokens: response.usage.input_tokens, - outputTokens: response.usage.output_tokens, - requestId: request.requestId - } - ); - } - return { text: response.output_text || '' }; } } @@ -168,6 +156,7 @@ export class OpenAiResponseApiUtils { protected createSimpleResponseApiStreamIterator( stream: AsyncIterable, requestId: string, + sessionId: string, modelId: string, tokenUsageService?: TokenUsageService, cancellationToken?: CancellationToken @@ -185,15 +174,12 @@ export class OpenAiResponseApiUtils { content: event.delta }; } else if (event.type === 'response.completed') { - if (tokenUsageService && event.response?.usage) { - await tokenUsageService.recordTokenUsage( - modelId, - { - inputTokens: event.response.usage.input_tokens, - outputTokens: event.response.usage.output_tokens, - requestId - } - ); + // Yield usage data when response completes + if (event.response?.usage) { + yield { + input_tokens: event.response.usage.input_tokens, + output_tokens: event.response.usage.output_tokens + }; } } else if (event.type === 'error') { console.error('Response API error:', event.message); @@ -320,8 +306,6 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator(); - protected totalInputTokens = 0; - protected totalOutputTokens = 0; protected iteration = 0; protected readonly maxIterations: number; protected readonly tools: FunctionTool[] | undefined; @@ -446,12 +430,6 @@ class ResponseApiToolCallIterator implements AsyncIterableIterator { this.done = true; - // Record final token usage - if (this.tokenUsageService && (this.totalInputTokens > 0 || this.totalOutputTokens > 0)) { - try { - await this.tokenUsageService.recordTokenUsage( - this.modelId, - { - inputTokens: this.totalInputTokens, - outputTokens: this.totalOutputTokens, - requestId: this.request.requestId - } - ); - } catch (error) { - console.error('Error recording token usage:', error); - } - } - // Resolve any outstanding requests if (this.terminalError) { this.requestQueue.forEach(request => request.reject(this.terminalError)); diff --git a/packages/ai-openai/src/node/openai-streaming-iterator.spec.ts b/packages/ai-openai/src/node/openai-streaming-iterator.spec.ts index 66d20acd612f7..1df730a9a10e3 100644 --- a/packages/ai-openai/src/node/openai-streaming-iterator.spec.ts +++ b/packages/ai-openai/src/node/openai-streaming-iterator.spec.ts @@ -45,7 +45,7 @@ describe('StreamingAsyncIterator', () => { }); function createIterator(withCancellationToken = false): StreamingAsyncIterator { - return new StreamingAsyncIterator(mockStream, '', withCancellationToken ? cts.token : undefined); + return new StreamingAsyncIterator(mockStream, '', '', withCancellationToken ? cts.token : undefined); } it('should yield messages in the correct order when consumed immediately', async () => { diff --git a/packages/ai-openai/src/node/openai-streaming-iterator.ts b/packages/ai-openai/src/node/openai-streaming-iterator.ts index 15ca438469cb3..1b94e85d7bd5f 100644 --- a/packages/ai-openai/src/node/openai-streaming-iterator.ts +++ b/packages/ai-openai/src/node/openai-streaming-iterator.ts @@ -14,7 +14,7 @@ // SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 // ***************************************************************************** -import { LanguageModelStreamResponsePart, TokenUsageService, TokenUsageParams, ToolCallResult, ToolCallTextResult } from '@theia/ai-core'; +import { LanguageModelStreamResponsePart, ToolCallResult, ToolCallTextResult } from '@theia/ai-core'; import { CancellationError, CancellationToken, Disposable, DisposableCollection } from '@theia/core'; import { Deferred } from '@theia/core/lib/common/promise-util'; import { ChatCompletionStream, ChatCompletionStreamEvents } from 'openai/lib/ChatCompletionStream'; @@ -32,8 +32,9 @@ export class StreamingAsyncIterator implements AsyncIterableIterator { @@ -61,18 +62,15 @@ export class StreamingAsyncIterator implements AsyncIterableIterator { - // Handle token usage reporting - if (chunk.usage && this.tokenUsageService && this.model) { + // Yield token usage as UsageResponsePart when available + if (chunk.usage) { const inputTokens = chunk.usage.prompt_tokens || 0; const outputTokens = chunk.usage.completion_tokens || 0; if (inputTokens > 0 || outputTokens > 0) { - const tokenUsageParams: TokenUsageParams = { - inputTokens, - outputTokens, - requestId - }; - this.tokenUsageService.recordTokenUsage(this.model, tokenUsageParams) - .catch(error => console.error('Error recording token usage:', error)); + this.handleIncoming({ + input_tokens: inputTokens, + output_tokens: outputTokens + }); } } // OpenAI API defines the type of a tool_call as optional but fails if it is not set diff --git a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts index 5711d008ea2e6..e598e41cb3bfe 100644 --- a/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts +++ b/packages/ai-vercel-ai/src/node/vercel-ai-language-model.ts @@ -33,11 +33,8 @@ import { CancellationToken, Disposable, ILogger } from '@theia/core'; import { CoreMessage, generateObject, - GenerateObjectResult, generateText, - GenerateTextResult, jsonSchema, - StepResult, streamText, TextStreamPart, tool, @@ -267,8 +264,6 @@ export class VercelAiModel implements LanguageModel { ...settings }); - await this.recordTokenUsage(response, request); - return { text: response.text }; } @@ -323,30 +318,12 @@ export class VercelAiModel implements LanguageModel { ...settings }); - await this.recordTokenUsage(response, request); - return { content: JSON.stringify(response.object), parsed: response.object }; } - private async recordTokenUsage( - result: GenerateObjectResult | GenerateTextResult, - request: UserRequest - ): Promise { - if (this.tokenUsageService && !isNaN(result.usage.completionTokens) && !isNaN(result.usage.promptTokens)) { - await this.tokenUsageService.recordTokenUsage( - this.id, - { - inputTokens: result.usage.promptTokens, - outputTokens: result.usage.completionTokens, - requestId: request.requestId - } - ); - } - } - protected async handleStreamingRequest( model: LanguageModelV1, request: UserRequest, @@ -366,15 +343,6 @@ export class VercelAiModel implements LanguageModel { maxRetries: this.maxRetries, toolCallStreaming: true, abortSignal, - onStepFinish: (stepResult: StepResult) => { - if (!isNaN(stepResult.usage.completionTokens) && !isNaN(stepResult.usage.promptTokens)) { - this.tokenUsageService?.recordTokenUsage(this.id, { - inputTokens: stepResult.usage.promptTokens, - outputTokens: stepResult.usage.completionTokens, - requestId: request.requestId - }); - } - }, ...settings });