Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 49 additions & 34 deletions core/src/agents/base_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
*/

import {Content} from '@google/genai';
import {trace} from '@opentelemetry/api';
import {context, trace} from '@opentelemetry/api';

import {createEvent, Event} from '../events/event.js';

import {
runAsyncGeneratorWithOtelContext,
traceAgentInvocation,
tracer,
} from '../telemetry/tracing.js';
import {CallbackContext} from './callback_context.js';
import {InvocationContext} from './invocation_context.js';

Expand Down Expand Up @@ -164,35 +169,41 @@ export abstract class BaseAgent {
async *runAsync(
parentContext: InvocationContext,
): AsyncGenerator<Event, void, void> {
const span = trace
.getTracer('gcp.vertex.agent')
.startSpan(`agent_run [${this.name}]`);
const span = tracer.startSpan(`invoke_agent ${this.name}`);
const ctx = trace.setSpan(context.active(), span);
try {
const context = this.createInvocationContext(parentContext);

const beforeAgentCallbackEvent =
await this.handleBeforeAgentCallback(context);
if (beforeAgentCallbackEvent) {
yield beforeAgentCallbackEvent;
}

if (context.endInvocation) {
return;
}

for await (const event of this.runAsyncImpl(context)) {
yield event;
}

if (context.endInvocation) {
return;
}

const afterAgentCallbackEvent =
await this.handleAfterAgentCallback(context);
if (afterAgentCallbackEvent) {
yield afterAgentCallbackEvent;
}
yield* runAsyncGeneratorWithOtelContext<BaseAgent, Event>(
ctx,
this,
async function* () {
const context = this.createInvocationContext(parentContext);

const beforeAgentCallbackEvent =
await this.handleBeforeAgentCallback(context);
if (beforeAgentCallbackEvent) {
yield beforeAgentCallbackEvent;
}

if (context.endInvocation) {
return;
}

traceAgentInvocation({agent: this, invocationContext: context});
for await (const event of this.runAsyncImpl(context)) {
yield event;
}

if (context.endInvocation) {
return;
}

const afterAgentCallbackEvent =
await this.handleAfterAgentCallback(context);
if (afterAgentCallbackEvent) {
yield afterAgentCallbackEvent;
}
},
);
} finally {
span.end();
}
Expand All @@ -205,15 +216,19 @@ export abstract class BaseAgent {
* @yields The events generated by the agent.
* @returns An AsyncGenerator that yields the events generated by the agent.
*/
// eslint-disable-next-line require-yield
async *runLive(
parentContext: InvocationContext, // eslint-disable-line @typescript-eslint/no-unused-vars
): AsyncGenerator<Event, void, void> {
const span = trace
.getTracer('gcp.vertex.agent')
.startSpan(`agent_run [${this.name}]`);
const span = tracer.startSpan(`invoke_agent ${this.name}`);
const ctx = trace.setSpan(context.active(), span);
try {
// TODO(b/425992518): Implement live mode.
yield* runAsyncGeneratorWithOtelContext<BaseAgent, Event>(
ctx,
this,
async function* () {
// TODO(b/425992518): Implement live mode.
},
);
throw new Error('Live mode is not implemented yet.');
} finally {
span.end();
Expand Down
83 changes: 73 additions & 10 deletions core/src/agents/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/

// TODO - b/436079721: implement traceMergedToolCalls, traceToolCall, tracer.
import {Content, createUserContent, FunctionCall, Part} from '@google/genai';
import {isEmpty} from 'lodash-es';

Expand All @@ -17,6 +16,11 @@ import {ToolContext} from '../tools/tool_context.js';
import {randomUUID} from '../utils/env_aware_utils.js';
import {logger} from '../utils/logger.js';

import {
traceMergedToolCalls,
tracer,
traceToolCall,
} from '../telemetry/tracing.js';
import {
SingleAfterToolCallback,
SingleBeforeToolCallback,
Expand Down Expand Up @@ -212,11 +216,61 @@ async function callToolAsync(
toolContext: ToolContext,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): Promise<any> {
// TODO - b/436079721: implement [tracer.start_as_current_span]
logger.debug(`callToolAsync ${tool.name}`);
return await tool.runAsync({args, toolContext});
return tracer.startActiveSpan(`execute_tool ${tool.name}`, async (span) => {
try {
logger.debug(`callToolAsync ${tool.name}`);
const result = await tool.runAsync({args, toolContext});
traceToolCall({
tool,
args,
functionResponseEvent: buildResponseEvent(
tool,
result,
toolContext,
toolContext.invocationContext,
),
});
return result;
} finally {
span.end();
}
});
}

function buildResponseEvent(
tool: BaseTool,
functionResult: unknown,
toolContext: ToolContext,
invocationContext: InvocationContext,
): Event {
let responseResult: Record<string, unknown>;
if (typeof functionResult !== 'object' || functionResult == null) {
responseResult = {result: functionResult};
} else {
responseResult = functionResult as Record<string, unknown>;
}

const partFunctionResponse: Part = {
functionResponse: {
name: tool.name,
response: responseResult,
id: toolContext.functionCallId,
},
};

const content: Content = {
role: 'user',
parts: [partFunctionResponse],
};

return createEvent({
invocationId: invocationContext.invocationId,
author: invocationContext.agent.name,
content: content,
actions: toolContext.actions,
branch: invocationContext.branch,
});
}
/**
* Handles function calls.
* Runtime behavior to pay attention to:
Expand Down Expand Up @@ -444,12 +498,21 @@ export async function handleFunctionCallList({
);

if (functionResponseEvents.length > 1) {
// TODO - b/436079721: implement [tracer.start_as_current_span]
logger.debug('execute_tool (merged)');
// TODO - b/436079721: implement [traceMergedToolCalls]
logger.debug('traceMergedToolCalls', {
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent.id,
tracer.startActiveSpan('execute_tool (merged)', (span) => {
try {
logger.debug('execute_tool (merged)');
// TODO - b/436079721: implement [traceMergedToolCalls]
logger.debug('traceMergedToolCalls', {
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent.id,
});
traceMergedToolCalls({
responseEventId: mergedEvent.id,
functionResponseEvent: mergedEvent,
});
} finally {
span.end();
}
});
}
return mergedEvent;
Expand Down
64 changes: 41 additions & 23 deletions core/src/agents/llm_agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
Part,
Schema,
} from '@google/genai';
import {context, trace} from '@opentelemetry/api';
import {cloneDeep} from 'lodash-es';
import {z} from 'zod';

Expand Down Expand Up @@ -56,6 +57,11 @@ import {ToolContext} from '../tools/tool_context.js';
import {base64Decode} from '../utils/env_aware_utils.js';
import {logger} from '../utils/logger.js';

import {
runAsyncGeneratorWithOtelContext,
traceCallLlm,
tracer,
} from '../telemetry/tracing.js';
import {BaseAgent, BaseAgentConfig} from './base_agent.js';
import {
BaseLlmRequestProcessor,
Expand Down Expand Up @@ -1707,26 +1713,35 @@ export class LlmAgent extends BaseAgent {
author: this.name,
branch: invocationContext.branch,
});
for await (const llmResponse of this.callLlmAsync(
invocationContext,
llmRequest,
modelResponseEvent,
)) {
// ======================================================================
// Postprocess after calling the LLM
// ======================================================================
for await (const event of this.postprocess(
invocationContext,
llmRequest,
llmResponse,
modelResponseEvent,
)) {
// Update the mutable event id to avoid conflict
modelResponseEvent.id = createNewEventId();
modelResponseEvent.timestamp = new Date().getTime();
yield event;
}
}
const span = tracer.startSpan('call_llm');
const ctx = trace.setSpan(context.active(), span);
yield* runAsyncGeneratorWithOtelContext<LlmAgent, Event>(
ctx,
this,
async function* () {
for await (const llmResponse of this.callLlmAsync(
invocationContext,
llmRequest,
modelResponseEvent,
)) {
// ======================================================================
// Postprocess after calling the LLM
// ======================================================================
for await (const event of this.postprocess(
invocationContext,
llmRequest,
llmResponse,
modelResponseEvent,
)) {
// Update the mutable event id to avoid conflict
modelResponseEvent.id = createNewEventId();
modelResponseEvent.timestamp = new Date().getTime();
yield event;
}
}
},
);
span.end();
}

private async *postprocess(
Expand Down Expand Up @@ -1885,7 +1900,6 @@ export class LlmAgent extends BaseAgent {

// Calls the LLM.
const llm = this.canonicalModel;
// TODO - b/436079721: Add tracer.start_as_current_span('call_llm')
if (invocationContext.runConfig?.supportCfc) {
// TODO - b/425992518: Implement CFC call path
// This is a hack, underneath it calls runLive. Which makes
Expand All @@ -1905,8 +1919,12 @@ export class LlmAgent extends BaseAgent {
llmRequest,
modelResponseEvent,
)) {
// TODO - b/436079721: Add trace_call_llm

traceCallLlm({
invocationContext,
eventId: modelResponseEvent.id,
llmRequest,
llmResponse,
});
// Runs after_model_callback if it exists.
const alteredLlmResponse = await this.handleAfterModelCallback(
invocationContext,
Expand Down
Loading
Loading