Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion agents/src/voice/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,21 @@ import { SynthesizeStream, StreamAdapter as TTSStreamAdapter } from '../tts/inde
import type { VAD } from '../vad.js';
import type { AgentActivity } from './agent_activity.js';
import type { AgentSession, TurnDetectionMode } from './agent_session.js';
import type { SpeechHandle } from './speech_handle.js';

export const asyncLocalStorage = new AsyncLocalStorage<{ functionCall?: FunctionCall }>();
export type ActiveToolCall = {
functionCall: FunctionCall;
speechHandles: SpeechHandle[];
};

export function createActiveToolCall(functionCall: FunctionCall): ActiveToolCall {
return {
functionCall,
speechHandles: [],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We keep track of all generateReply created inside tool call so we can wait them to be scheduled

};
}

export const toolCallContext = new AsyncLocalStorage<ActiveToolCall | undefined>();
export const STOP_RESPONSE_SYMBOL = Symbol('StopResponse');

export class StopResponse extends Error {
Expand Down
30 changes: 19 additions & 11 deletions agents/src/voice/agent_activity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import { TTS, type TTSError } from '../tts/tts.js';
import { Future, Task, cancelAndWait, waitFor } from '../utils.js';
import { VAD, type VADEvent } from '../vad.js';
import type { Agent, ModelSettings } from './agent.js';
import { StopResponse, asyncLocalStorage } from './agent.js';
import { StopResponse, toolCallContext } from './agent.js';
import { type AgentSession, type TurnDetectionMode } from './agent_session.js';
import {
AudioRecognition,
Expand Down Expand Up @@ -802,16 +802,18 @@ export class AgentActivity implements RecognitionHooks {
throw new Error('trying to generate reply without an LLM model');
}

const functionCall = asyncLocalStorage.getStore()?.functionCall;
if (toolChoice === undefined && functionCall !== undefined) {
// when generateReply is called inside a tool, set toolChoice to 'none' by default
toolChoice = 'none';
}

const handle = SpeechHandle.create({
allowInterruptions: allowInterruptions ?? this.allowInterruptions,
});

const activeToolCall = toolCallContext.getStore();
if (toolChoice === undefined && activeToolCall !== undefined) {
// when generateReply is called inside a tool, set toolChoice to 'none' by default
toolChoice = 'none';
// add the speech handle to the active tool call
activeToolCall.speechHandles.push(handle);
}

this.agentSession.emit(
AgentSessionEventTypes.SpeechCreated,
createSpeechCreatedEvent({
Expand Down Expand Up @@ -1232,12 +1234,14 @@ export class AgentActivity implements RecognitionHooks {
//TODO(AJS-272): before executing tools, make sure we generated all the text
// (this ensure everything is kept ordered)

const onToolExecutionStarted = (_: FunctionCall) => {
// TODO(brian): handle speech_handle item_added
const onToolExecutionStarted = (f: FunctionCall) => {
speechHandle._itemAdded([f]);
};

const onToolExecutionCompleted = (_: ToolExecutionOutput) => {
// TODO(brian): handle speech_handle item_added
const onToolExecutionCompleted = (out: ToolExecutionOutput) => {
if (out.toolCallOutput) {
speechHandle._itemAdded([out.toolCallOutput]);
}
};

const [executeToolsTask, toolOutput] = performToolExecutions({
Expand Down Expand Up @@ -1598,6 +1602,8 @@ export class AgentActivity implements RecognitionHooks {

const onToolExecutionStarted = (f: FunctionCall) => {
speechHandle._itemAdded([f]);
this.agent._chatCtx.insert(f);
this.agentSession.chatCtx.insert(f);
};

const onToolExecutionCompleted = (out: ToolExecutionOutput) => {
Expand Down Expand Up @@ -1735,6 +1741,8 @@ export class AgentActivity implements RecognitionHooks {
let ignoreTaskSwitch: boolean = false;

for (const sanitizedOut of toolOutput.output) {
functionToolsExecutedEvent.functionCalls.push(sanitizedOut.toolCall);

if (sanitizedOut.toolCallOutput !== undefined) {
functionToolsExecutedEvent.functionCallOutputs.push(sanitizedOut.toolCallOutput);
if (sanitizedOut.replyRequired) {
Expand Down
22 changes: 19 additions & 3 deletions agents/src/voice/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ import {
import { log } from '../log.js';
import { IdentityTransform } from '../stream/identity_transform.js';
import { Future, Task, shortuuid, toError } from '../utils.js';
import { type Agent, type ModelSettings, asyncLocalStorage, isStopResponse } from './agent.js';
import {
type Agent,
type ModelSettings,
createActiveToolCall,
isStopResponse,
toolCallContext,
} from './agent.js';
import type { AgentSession } from './agent_session.js';
import type { AudioOutput, LLMNode, TTSNode, TextOutput } from './io.js';
import { RunContext } from './run_context.js';
Expand Down Expand Up @@ -772,12 +778,22 @@ export function performToolExecutions({
'Executing LLM tool call',
);

const toolExecution = asyncLocalStorage.run({ functionCall: toolCall }, async () => {
return await tool.execute(parsedArgs, {
const toolExecution = toolCallContext.run(createActiveToolCall(toolCall), async () => {
const result = await tool.execute(parsedArgs, {
ctx: new RunContext(session, speechHandle, toolCall),
toolCallId: toolCall.callId,
abortSignal: signal,
});

// [IMPORTANT] wait for all speech handles created inside tool call to be scheduled first
const activeToolCall = toolCallContext.getStore()!;
if (activeToolCall.speechHandles.length > 0) {
await Promise.all(
activeToolCall.speechHandles.map((handle) => handle._waitForScheduled()),
);
}

return result;
});

const tracableToolExecution = async (toolExecTask: Promise<unknown>) => {
Expand Down
6 changes: 3 additions & 3 deletions agents/src/voice/speech_handle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//
// SPDX-License-Identifier: Apache-2.0
import type { ChatItem } from '../llm/index.js';
import { Event, Future, shortuuid } from '../utils.js';
import type { Task } from '../utils.js';
import { asyncLocalStorage } from './agent.js';
import { Event, Future, shortuuid } from '../utils.js';
import { toolCallContext } from './agent.js';

export class SpeechHandle {
/** Priority for messages that should be played after all other messages in the queue */
Expand Down Expand Up @@ -124,7 +124,7 @@ export class SpeechHandle {
* has entirely played out, including any tool calls and response follow-ups.
*/
async waitForPlayout(): Promise<void> {
const store = asyncLocalStorage.getStore();
const store = toolCallContext.getStore();
if (store && store?.functionCall) {
throw new Error(
`Cannot call 'SpeechHandle.waitForPlayout()' from inside the function tool '${store.functionCall.name}'. ` +
Expand Down
11 changes: 7 additions & 4 deletions examples/src/basic_tool_call_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import {
} from '@livekit/agents';
import * as deepgram from '@livekit/agents-plugin-deepgram';
import * as elevenlabs from '@livekit/agents-plugin-elevenlabs';
import * as google from '@livekit/agents-plugin-google';
import * as livekit from '@livekit/agents-plugin-livekit';
import * as openai from '@livekit/agents-plugin-openai';
import * as silero from '@livekit/agents-plugin-silero';
import { fileURLToPath } from 'node:url';
import { z } from 'zod';
Expand Down Expand Up @@ -50,7 +50,10 @@ export default defineAgent({
location: z.string().describe('The location to get the weather for'),
}),
execute: async ({ location }, { ctx }) => {
ctx.session.say('Checking the weather, please wait a moment haha...');
ctx.session.generateReply({
userInput: 'Tell user you are queuing the request. Counting down from 3 to 1.',
});

return `The weather in ${location} is sunny today.`;
},
});
Expand All @@ -63,7 +66,7 @@ export default defineAgent({
}),
execute: async ({ room, switchTo }, { ctx }) => {
ctx.session.generateReply({
userInput: 'Tell user wait a moment for about 10 seconds',
userInput: 'Tell user you are working on it. please wait a moment.',
});

return `The light in the ${room} is now ${switchTo}.`;
Expand Down Expand Up @@ -137,7 +140,7 @@ export default defineAgent({
vad,
stt: new deepgram.STT(),
tts: new elevenlabs.TTS(),
llm: new google.LLM(),
llm: new openai.realtime.RealtimeModel(),
// to use realtime model, replace the stt, llm, tts and vad with the following
// llm: new openai.realtime.RealtimeModel(),
userData: { number: 0 },
Expand Down