Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 51 additions & 39 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,22 +445,9 @@ async def run(
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [
h.agent_name
for h in await AgentRunner._get_handoffs(current_agent, context_wrapper)
]
if output_schema := AgentRunner._get_output_schema(current_agent):
output_type_name = output_schema.name()
else:
output_type_name = "str"

current_span = agent_span(
name=current_agent.name,
handoffs=handoff_names,
output_type=output_type_name,
current_span = await AgentRunner._create_agent_span(
current_agent, all_tools, context_wrapper
)
current_span.start(mark_as_current=True)
current_span.span_data.tools = [t.name for t in all_tools]

current_turn += 1
if current_turn > max_turns:
Expand All @@ -481,9 +468,8 @@ async def run(
input_guardrail_results, turn_result = await asyncio.gather(
self._run_input_guardrails(
starting_agent,
starting_agent.input_guardrails
+ (run_config.input_guardrails or []),
_copy_str_or_list(prepared_input),
self._get_combined_input_guardrails(starting_agent, run_config),
_copy_str_or_list(input),
context_wrapper,
),
self._run_single_turn(
Expand Down Expand Up @@ -522,7 +508,7 @@ async def run(

if isinstance(turn_result.next_step, NextStepFinalOutput):
output_guardrail_results = await self._run_output_guardrails(
current_agent.output_guardrails + (run_config.output_guardrails or []),
self._get_combined_output_guardrails(current_agent, run_config),
current_agent,
turn_result.next_step.output,
context_wrapper,
Expand Down Expand Up @@ -793,23 +779,9 @@ async def _start_streaming(
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
if current_span is None:
handoff_names = [
h.agent_name
for h in await cls._get_handoffs(current_agent, context_wrapper)
]
if output_schema := cls._get_output_schema(current_agent):
output_type_name = output_schema.name()
else:
output_type_name = "str"

current_span = agent_span(
name=current_agent.name,
handoffs=handoff_names,
output_type=output_type_name,
current_span = await cls._create_agent_span(
current_agent, all_tools, context_wrapper
)
current_span.start(mark_as_current=True)
tool_names = [t.name for t in all_tools]
current_span.span_data.tools = tool_names
current_turn += 1
streamed_result.current_turn = current_turn

Expand All @@ -829,8 +801,8 @@ async def _start_streaming(
streamed_result._input_guardrails_task = asyncio.create_task(
cls._run_input_guardrails_with_queue(
starting_agent,
starting_agent.input_guardrails + (run_config.input_guardrails or []),
ItemHelpers.input_to_new_input_list(prepared_input),
cls._get_combined_input_guardrails(starting_agent, run_config),
ItemHelpers.input_to_new_input_list(starting_input),
context_wrapper,
streamed_result,
current_span,
Expand Down Expand Up @@ -868,8 +840,7 @@ async def _start_streaming(
elif isinstance(turn_result.next_step, NextStepFinalOutput):
streamed_result._output_guardrails_task = asyncio.create_task(
cls._run_output_guardrails(
current_agent.output_guardrails
+ (run_config.output_guardrails or []),
cls._get_combined_output_guardrails(current_agent, run_config),
current_agent,
turn_result.next_step.output,
context_wrapper,
Expand Down Expand Up @@ -1474,6 +1445,47 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

return run_config.model_provider.get_model(agent.model)

@classmethod
async def _create_agent_span(
cls,
agent: Agent[Any],
all_tools: list[Tool],
context_wrapper: RunContextWrapper[Any],
) -> Span[AgentSpanData]:
"""Create and start an agent span with proper metadata."""
handoff_names = [h.agent_name for h in await cls._get_handoffs(agent, context_wrapper)]
if output_schema := cls._get_output_schema(agent):
output_type_name = output_schema.name()
else:
output_type_name = "str"

span = agent_span(
name=agent.name,
handoffs=handoff_names,
output_type=output_type_name,
)
span.start(mark_as_current=True)
span.span_data.tools = [t.name for t in all_tools]
return span

@classmethod
def _get_combined_input_guardrails(
cls,
agent: Agent[Any],
run_config: RunConfig,
) -> list[InputGuardrail[Any]]:
"""Get combined input guardrails from agent and run config."""
return agent.input_guardrails + (run_config.input_guardrails or [])

@classmethod
def _get_combined_output_guardrails(
cls,
agent: Agent[Any],
run_config: RunConfig,
) -> list[OutputGuardrail[Any]]:
"""Get combined output guardrails from agent and run config."""
return agent.output_guardrails + (run_config.output_guardrails or [])

@classmethod
async def _prepare_input_with_session(
cls,
Expand Down