From 0a71c04305498c4a2beb4c05cd7efb3601806533 Mon Sep 17 00:00:00 2001 From: Aaron Ang Date: Thu, 4 Sep 2025 19:54:39 -0700 Subject: [PATCH] fix: send raw input instead of prepared input with session history to input guardrail --- src/agents/run.py | 90 +++++++++++++++++++++++++++-------------------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index 4575edb3f..ef5d72f48 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -445,22 +445,9 @@ async def run( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [ - h.agent_name - for h in await AgentRunner._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := AgentRunner._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, + current_span = await AgentRunner._create_agent_span( + current_agent, all_tools, context_wrapper ) - current_span.start(mark_as_current=True) - current_span.span_data.tools = [t.name for t in all_tools] current_turn += 1 if current_turn > max_turns: @@ -481,9 +468,8 @@ async def run( input_guardrail_results, turn_result = await asyncio.gather( self._run_input_guardrails( starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), - _copy_str_or_list(prepared_input), + self._get_combined_input_guardrails(starting_agent, run_config), + _copy_str_or_list(input), context_wrapper, ), self._run_single_turn( @@ -522,7 +508,7 @@ async def run( if isinstance(turn_result.next_step, NextStepFinalOutput): output_guardrail_results = await self._run_output_guardrails( - current_agent.output_guardrails + (run_config.output_guardrails or []), + self._get_combined_output_guardrails(current_agent, run_config), current_agent, turn_result.next_step.output, context_wrapper, @@ -793,23 +779,9 @@ async def _start_streaming( # Start an agent span if we don't have one. This span is ended if the current # agent changes, or if the agent loop ends. if current_span is None: - handoff_names = [ - h.agent_name - for h in await cls._get_handoffs(current_agent, context_wrapper) - ] - if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.name() - else: - output_type_name = "str" - - current_span = agent_span( - name=current_agent.name, - handoffs=handoff_names, - output_type=output_type_name, + current_span = await cls._create_agent_span( + current_agent, all_tools, context_wrapper ) - current_span.start(mark_as_current=True) - tool_names = [t.name for t in all_tools] - current_span.span_data.tools = tool_names current_turn += 1 streamed_result.current_turn = current_turn @@ -829,8 +801,8 @@ async def _start_streaming( streamed_result._input_guardrails_task = asyncio.create_task( cls._run_input_guardrails_with_queue( starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), - ItemHelpers.input_to_new_input_list(prepared_input), + cls._get_combined_input_guardrails(starting_agent, run_config), + ItemHelpers.input_to_new_input_list(starting_input), context_wrapper, streamed_result, current_span, @@ -868,8 +840,7 @@ async def _start_streaming( elif isinstance(turn_result.next_step, NextStepFinalOutput): streamed_result._output_guardrails_task = asyncio.create_task( cls._run_output_guardrails( - current_agent.output_guardrails - + (run_config.output_guardrails or []), + cls._get_combined_output_guardrails(current_agent, run_config), current_agent, turn_result.next_step.output, context_wrapper, @@ -1474,6 +1445,47 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return run_config.model_provider.get_model(agent.model) + @classmethod + async def _create_agent_span( + cls, + agent: Agent[Any], + all_tools: list[Tool], + context_wrapper: RunContextWrapper[Any], + ) -> Span[AgentSpanData]: + """Create and start an agent span with proper metadata.""" + handoff_names = [h.agent_name for h in await cls._get_handoffs(agent, context_wrapper)] + if output_schema := cls._get_output_schema(agent): + output_type_name = output_schema.name() + else: + output_type_name = "str" + + span = agent_span( + name=agent.name, + handoffs=handoff_names, + output_type=output_type_name, + ) + span.start(mark_as_current=True) + span.span_data.tools = [t.name for t in all_tools] + return span + + @classmethod + def _get_combined_input_guardrails( + cls, + agent: Agent[Any], + run_config: RunConfig, + ) -> list[InputGuardrail[Any]]: + """Get combined input guardrails from agent and run config.""" + return agent.input_guardrails + (run_config.input_guardrails or []) + + @classmethod + def _get_combined_output_guardrails( + cls, + agent: Agent[Any], + run_config: RunConfig, + ) -> list[OutputGuardrail[Any]]: + """Get combined output guardrails from agent and run config.""" + return agent.output_guardrails + (run_config.output_guardrails or []) + @classmethod async def _prepare_input_with_session( cls,