diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 99e287675..8ab68cd34 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -70,7 +70,7 @@ class OutputGuardrailResult: @dataclass class InputGuardrail(Generic[TContext]): - """Input guardrails are checks that run in parallel to the agent's execution. + """Input guardrails are checks that run either in parallel with the agent or before it starts. They can be used to do things like: - Check if input messages are off-topic - Take over control of the agent's execution if an unexpected input is detected @@ -97,6 +97,11 @@ class InputGuardrail(Generic[TContext]): function's name. """ + run_in_parallel: bool = True + """Whether the guardrail runs concurrently with the agent (True, default) or before + the agent starts (False). + """ + def get_name(self) -> str: if self.name: return self.name @@ -209,6 +214,7 @@ def input_guardrail( def input_guardrail( *, name: str | None = None, + run_in_parallel: bool = True, ) -> Callable[ [_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]], InputGuardrail[TContext_co], @@ -221,6 +227,7 @@ def input_guardrail( | None = None, *, name: str | None = None, + run_in_parallel: bool = True, ) -> ( InputGuardrail[TContext_co] | Callable[ @@ -235,8 +242,14 @@ def input_guardrail( @input_guardrail def my_sync_guardrail(...): ... - @input_guardrail(name="guardrail_name") + @input_guardrail(name="guardrail_name", run_in_parallel=False) async def my_async_guardrail(...): ... + + Args: + func: The guardrail function to wrap. + name: Optional name for the guardrail. If not provided, uses the function's name. + run_in_parallel: Whether to run the guardrail concurrently with the agent (True, default) + or before the agent starts (False). """ def decorator( @@ -246,6 +259,7 @@ def decorator( guardrail_function=f, # If not set, guardrail name uses the function’s name by default. name=name if name else f.__name__, + run_in_parallel=run_in_parallel, ) if func is not None: diff --git a/src/agents/run.py b/src/agents/run.py index c14f13e3f..da454757f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -601,11 +601,31 @@ async def run( ) if current_turn == 1: + # Separate guardrails based on execution mode. + all_input_guardrails = starting_agent.input_guardrails + ( + run_config.input_guardrails or [] + ) + sequential_guardrails = [ + g for g in all_input_guardrails if not g.run_in_parallel + ] + parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] + + # Run blocking guardrails first, before agent starts. + # (will raise exception if tripwire triggered). + sequential_results = [] + if sequential_guardrails: + sequential_results = await self._run_input_guardrails( + starting_agent, + sequential_guardrails, + _copy_str_or_list(prepared_input), + context_wrapper, + ) + + # Run parallel guardrails + agent together. input_guardrail_results, turn_result = await asyncio.gather( self._run_input_guardrails( starting_agent, - starting_agent.input_guardrails - + (run_config.input_guardrails or []), + parallel_guardrails, _copy_str_or_list(prepared_input), context_wrapper, ), @@ -622,6 +642,9 @@ async def run( server_conversation_tracker=server_conversation_tracker, ), ) + + # Combine sequential and parallel results. + input_guardrail_results = sequential_results + input_guardrail_results else: turn_result = await self._run_single_turn( agent=current_agent, @@ -941,6 +964,11 @@ async def _run_input_guardrails_with_queue( for done in asyncio.as_completed(guardrail_tasks): result = await done if result.output.tripwire_triggered: + # Cancel all remaining guardrail tasks if a tripwire is triggered. + for t in guardrail_tasks: + t.cancel() + # Wait for cancellations to propagate by awaiting the cancelled tasks. + await asyncio.gather(*guardrail_tasks, return_exceptions=True) _error_tracing.attach_error_to_span( parent_span, SpanError( @@ -951,6 +979,9 @@ async def _run_input_guardrails_with_queue( }, ), ) + queue.put_nowait(result) + guardrail_results.append(result) + break queue.put_nowait(result) guardrail_results.append(result) except Exception: @@ -958,7 +989,9 @@ async def _run_input_guardrails_with_queue( t.cancel() raise - streamed_result.input_guardrail_results = guardrail_results + streamed_result.input_guardrail_results = ( + streamed_result.input_guardrail_results + guardrail_results + ) @classmethod async def _start_streaming( @@ -1050,11 +1083,36 @@ async def _start_streaming( break if current_turn == 1: - # Run the input guardrails in the background and put the results on the queue + # Separate guardrails based on execution mode. + all_input_guardrails = starting_agent.input_guardrails + ( + run_config.input_guardrails or [] + ) + sequential_guardrails = [ + g for g in all_input_guardrails if not g.run_in_parallel + ] + parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] + + # Run sequential guardrails first. + if sequential_guardrails: + await cls._run_input_guardrails_with_queue( + starting_agent, + sequential_guardrails, + ItemHelpers.input_to_new_input_list(prepared_input), + context_wrapper, + streamed_result, + current_span, + ) + # Check if any blocking guardrail triggered and raise before starting agent. + for result in streamed_result.input_guardrail_results: + if result.output.tripwire_triggered: + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + raise InputGuardrailTripwireTriggered(result) + + # Run parallel guardrails in background. streamed_result._input_guardrails_task = asyncio.create_task( cls._run_input_guardrails_with_queue( starting_agent, - starting_agent.input_guardrails + (run_config.input_guardrails or []), + parallel_guardrails, ItemHelpers.input_to_new_input_list(prepared_input), context_wrapper, streamed_result, @@ -1619,6 +1677,8 @@ async def _run_input_guardrails( # Cancel all guardrail tasks if a tripwire is triggered. for t in guardrail_tasks: t.cancel() + # Wait for cancellations to propagate by awaiting the cancelled tasks. + await asyncio.gather(*guardrail_tasks, return_exceptions=True) _error_tracing.attach_error_to_current_span( SpanError( message="Guardrail tripwire triggered", diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index c9f318c32..199564ef5 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio +import time from typing import Any +from unittest.mock import patch import pytest @@ -8,13 +11,20 @@ Agent, GuardrailFunctionOutput, InputGuardrail, + InputGuardrailTripwireTriggered, OutputGuardrail, + RunConfig, RunContextWrapper, + Runner, TResponseInputItem, UserError, + function_tool, ) from agents.guardrail import input_guardrail, output_guardrail +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + def get_sync_guardrail(triggers: bool, output_info: Any | None = None): def sync_guardrail( @@ -260,3 +270,1132 @@ async def test_output_guardrail_decorators(): assert not result.output.tripwire_triggered assert result.output.output_info == "test_4" assert guardrail.get_name() == "Custom name" + + +@pytest.mark.asyncio +async def test_input_guardrail_run_in_parallel_default(): + guardrail = InputGuardrail( + guardrail_function=lambda ctx, agent, input: GuardrailFunctionOutput( + output_info=None, tripwire_triggered=False + ) + ) + assert guardrail.run_in_parallel is True + + +@pytest.mark.asyncio +async def test_input_guardrail_run_in_parallel_false(): + guardrail = InputGuardrail( + guardrail_function=lambda ctx, agent, input: GuardrailFunctionOutput( + output_info=None, tripwire_triggered=False + ), + run_in_parallel=False, + ) + assert guardrail.run_in_parallel is False + + +@pytest.mark.asyncio +async def test_input_guardrail_decorator_with_run_in_parallel(): + @input_guardrail(run_in_parallel=False) + def blocking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="blocking", + tripwire_triggered=False, + ) + + assert blocking_guardrail.run_in_parallel is False + result = await blocking_guardrail.run( + agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None) + ) + assert not result.output.tripwire_triggered + assert result.output.output_info == "blocking" + + +@pytest.mark.asyncio +async def test_input_guardrail_decorator_with_name_and_run_in_parallel(): + @input_guardrail(name="custom_name", run_in_parallel=False) + def named_blocking_guardrail( + context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput( + output_info="named_blocking", + tripwire_triggered=False, + ) + + assert named_blocking_guardrail.get_name() == "custom_name" + assert named_blocking_guardrail.run_in_parallel is False + + +@pytest.mark.asyncio +async def test_parallel_guardrail_runs_concurrently_with_agent(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.3) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="parallel_ok", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'hello'", + input_guardrails=[parallel_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + result = await Runner.run(agent, "test input") + + assert guardrail_executed is True + assert result.final_output is not None + assert len(result.input_guardrail_results) == 1 + assert result.input_guardrail_results[0].output.output_info == "parallel_ok" + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + +@pytest.mark.asyncio +async def test_parallel_guardrail_runs_concurrently_with_agent_streaming(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.1) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="parallel_streaming_ok", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="streaming_agent", + instructions="Reply with 'hello'", + input_guardrails=[parallel_check], + model=model, + ) + model.set_next_output([get_text_message("hello from stream")]) + + result = Runner.run_streamed(agent, "test input") + + received_events = False + async for _event in result.stream_events(): + received_events = True + + assert guardrail_executed is True + assert received_events is True + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_prevents_agent_execution(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + guardrail_executed = True + await asyncio.sleep(0.3) + return GuardrailFunctionOutput( + output_info="security_violation", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'hello'", + input_guardrails=[blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with pytest.raises(InputGuardrailTripwireTriggered) as exc_info: + await Runner.run(agent, "test input") + + assert guardrail_executed is True + assert exc_info.value.guardrail_result.output.output_info == "security_violation" + assert model.first_turn_args is None, "Model should not have been called" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_prevents_agent_execution_streaming(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + guardrail_executed = True + await asyncio.sleep(0.3) + return GuardrailFunctionOutput( + output_info="blocked_streaming", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="streaming_agent", + instructions="Reply with a long message", + input_guardrails=[blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + result = Runner.run_streamed(agent, "test input") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert guardrail_executed is True + assert model.first_turn_args is None, "Model should not have been called" + + +@pytest.mark.asyncio +async def test_parallel_guardrail_may_not_prevent_tool_execution(): + tool_was_executed = False + guardrail_executed = False + + @function_tool + def fast_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=True) + async def slow_parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.5) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="slow_parallel_triggered", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="agent_with_tools", + instructions="Call the fast_tool immediately", + tools=[fast_tool], + input_guardrails=[slow_parallel_check], + model=model, + ) + model.set_next_output([get_function_tool_call("fast_tool", arguments="{}")]) + model.set_next_output([get_text_message("done")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "trigger guardrail") + + assert guardrail_executed is True + assert tool_was_executed is True, ( + "Expected tool to execute before slow parallel guardrail triggered" + ) + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + +@pytest.mark.asyncio +async def test_parallel_guardrail_may_not_prevent_tool_execution_streaming(): + tool_was_executed = False + guardrail_executed = False + + @function_tool + def fast_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=True) + async def slow_parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.5) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="slow_parallel_triggered_streaming", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="agent_with_tools", + instructions="Call the fast_tool immediately", + tools=[fast_tool], + input_guardrails=[slow_parallel_check], + model=model, + ) + model.set_next_output([get_function_tool_call("fast_tool", arguments="{}")]) + model.set_next_output([get_text_message("done")]) + + result = Runner.run_streamed(agent, "trigger guardrail") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert guardrail_executed is True + assert tool_was_executed is True, ( + "Expected tool to execute before slow parallel guardrail triggered" + ) + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_prevents_tool_execution(): + tool_was_executed = False + guardrail_executed = False + + @function_tool + def dangerous_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=False) + async def security_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.3) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="blocked_dangerous_input", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="agent_with_tools", + instructions="Call the dangerous_tool immediately", + tools=[dangerous_tool], + input_guardrails=[security_check], + model=model, + ) + model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "trigger guardrail") + + assert guardrail_executed is True + assert tool_was_executed is False + assert model.first_turn_args is None, "Model should not have been called" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_prevents_tool_execution_streaming(): + tool_was_executed = False + guardrail_executed = False + + @function_tool + def dangerous_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=False) + async def security_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.3) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="blocked_dangerous_input_streaming", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="agent_with_tools", + instructions="Call the dangerous_tool immediately", + tools=[dangerous_tool], + input_guardrails=[security_check], + model=model, + ) + model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")]) + + result = Runner.run_streamed(agent, "trigger guardrail") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert guardrail_executed is True + assert tool_was_executed is False + assert model.first_turn_args is None, "Model should not have been called" + + +@pytest.mark.asyncio +async def test_parallel_guardrail_passes_agent_continues(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.1) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="parallel_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'success'", + input_guardrails=[parallel_check], + model=model, + ) + model.set_next_output([get_text_message("success")]) + + result = await Runner.run(agent, "test input") + + assert guardrail_executed is True + assert result.final_output is not None + assert model.first_turn_args is not None, "Model should have been called" + + +@pytest.mark.asyncio +async def test_parallel_guardrail_passes_agent_continues_streaming(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.1) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="parallel_passed_streaming", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'success'", + input_guardrails=[parallel_check], + model=model, + ) + model.set_next_output([get_text_message("success")]) + + result = Runner.run_streamed(agent, "test input") + + received_events = False + async for _event in result.stream_events(): + received_events = True + + assert guardrail_executed is True + assert received_events is True + assert model.first_turn_args is not None, "Model should have been called" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_passes_agent_continues(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.3) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="blocking_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'success'", + input_guardrails=[blocking_check], + model=model, + ) + model.set_next_output([get_text_message("success")]) + + result = await Runner.run(agent, "test input") + + assert guardrail_executed is True + assert result.final_output is not None + assert model.first_turn_args is not None, "Model should have been called after guardrail passed" + + +@pytest.mark.asyncio +async def test_blocking_guardrail_passes_agent_continues_streaming(): + guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal guardrail_executed + await asyncio.sleep(0.3) + guardrail_executed = True + return GuardrailFunctionOutput( + output_info="blocking_passed_streaming", + tripwire_triggered=False, + ) + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'success'", + input_guardrails=[blocking_check], + model=model, + ) + model.set_next_output([get_text_message("success")]) + + result = Runner.run_streamed(agent, "test input") + + received_events = False + async for _event in result.stream_events(): + received_events = True + + assert guardrail_executed is True + assert received_events is True + assert model.first_turn_args is not None, "Model should have been called after guardrail passed" + + +@pytest.mark.asyncio +async def test_mixed_blocking_and_parallel_guardrails(): + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="blocking_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["parallel_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["parallel_end"] = time.time() + return GuardrailFunctionOutput( + output_info="parallel_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + + original_get_response = model.get_response + + async def tracked_get_response(*args, **kwargs): + timestamps["model_called"] = time.time() + return await original_get_response(*args, **kwargs) + + agent = Agent( + name="mixed_agent", + instructions="Reply with 'hello'", + input_guardrails=[blocking_check, parallel_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with patch.object(model, "get_response", side_effect=tracked_get_response): + result = await Runner.run(agent, "test input") + + assert result.final_output is not None + assert len(result.input_guardrail_results) == 2 + + assert "blocking_start" in timestamps + assert "blocking_end" in timestamps + assert "parallel_start" in timestamps + assert "parallel_end" in timestamps + assert "model_called" in timestamps + + assert timestamps["blocking_end"] <= timestamps["parallel_start"], ( + "Blocking must complete before parallel starts" + ) + assert timestamps["blocking_end"] <= timestamps["model_called"], ( + "Blocking must complete before model is called" + ) + assert timestamps["model_called"] <= timestamps["parallel_end"], ( + "Model called while parallel guardrail still running" + ) + assert model.first_turn_args is not None, ( + "Model should have been called after blocking guardrails passed" + ) + + +@pytest.mark.asyncio +async def test_mixed_blocking_and_parallel_guardrails_streaming(): + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="blocking_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=True) + async def parallel_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["parallel_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["parallel_end"] = time.time() + return GuardrailFunctionOutput( + output_info="parallel_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + + original_stream_response = model.stream_response + + async def tracked_stream_response(*args, **kwargs): + timestamps["model_called"] = time.time() + async for event in original_stream_response(*args, **kwargs): + yield event + + agent = Agent( + name="mixed_agent", + instructions="Reply with 'hello'", + input_guardrails=[blocking_check, parallel_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with patch.object(model, "stream_response", side_effect=tracked_stream_response): + result = Runner.run_streamed(agent, "test input") + + received_events = False + async for _event in result.stream_events(): + received_events = True + + assert received_events is True + assert "blocking_start" in timestamps + assert "blocking_end" in timestamps + assert "parallel_start" in timestamps + assert "parallel_end" in timestamps + assert "model_called" in timestamps + + assert timestamps["blocking_end"] <= timestamps["parallel_start"], ( + "Blocking must complete before parallel starts" + ) + assert timestamps["blocking_end"] <= timestamps["model_called"], ( + "Blocking must complete before model is called" + ) + assert timestamps["model_called"] <= timestamps["parallel_end"], ( + "Model called while parallel guardrail still running" + ) + assert model.first_turn_args is not None, ( + "Model should have been called after blocking guardrails passed" + ) + + +@pytest.mark.asyncio +async def test_multiple_blocking_guardrails_complete_before_agent(): + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def first_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["first_blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["first_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="first_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=False) + async def second_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["second_blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["second_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="second_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + + original_get_response = model.get_response + + async def tracked_get_response(*args, **kwargs): + timestamps["model_called"] = time.time() + return await original_get_response(*args, **kwargs) + + agent = Agent( + name="multi_blocking_agent", + instructions="Reply with 'hello'", + input_guardrails=[first_blocking_check, second_blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with patch.object(model, "get_response", side_effect=tracked_get_response): + result = await Runner.run(agent, "test input") + + assert result.final_output is not None + assert len(result.input_guardrail_results) == 2 + + assert "first_blocking_start" in timestamps + assert "first_blocking_end" in timestamps + assert "second_blocking_start" in timestamps + assert "second_blocking_end" in timestamps + assert "model_called" in timestamps + + assert timestamps["first_blocking_end"] <= timestamps["model_called"], ( + "First blocking guardrail must complete before model is called" + ) + assert timestamps["second_blocking_end"] <= timestamps["model_called"], ( + "Second blocking guardrail must complete before model is called" + ) + assert model.first_turn_args is not None, ( + "Model should have been called after all blocking guardrails passed" + ) + + +@pytest.mark.asyncio +async def test_multiple_blocking_guardrails_complete_before_agent_streaming(): + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def first_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["first_blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["first_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="first_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=False) + async def second_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + timestamps["second_blocking_start"] = time.time() + await asyncio.sleep(0.3) + timestamps["second_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="second_passed", + tripwire_triggered=False, + ) + + model = FakeModel() + + original_stream_response = model.stream_response + + async def tracked_stream_response(*args, **kwargs): + timestamps["model_called"] = time.time() + async for event in original_stream_response(*args, **kwargs): + yield event + + agent = Agent( + name="multi_blocking_agent", + instructions="Reply with 'hello'", + input_guardrails=[first_blocking_check, second_blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with patch.object(model, "stream_response", side_effect=tracked_stream_response): + result = Runner.run_streamed(agent, "test input") + + received_events = False + async for _event in result.stream_events(): + received_events = True + + assert received_events is True + assert "first_blocking_start" in timestamps + assert "first_blocking_end" in timestamps + assert "second_blocking_start" in timestamps + assert "second_blocking_end" in timestamps + assert "model_called" in timestamps + + assert timestamps["first_blocking_end"] <= timestamps["model_called"], ( + "First blocking guardrail must complete before model is called" + ) + assert timestamps["second_blocking_end"] <= timestamps["model_called"], ( + "Second blocking guardrail must complete before model is called" + ) + assert model.first_turn_args is not None, ( + "Model should have been called after all blocking guardrails passed" + ) + + +@pytest.mark.asyncio +async def test_multiple_blocking_guardrails_one_triggers(): + timestamps = {} + first_guardrail_executed = False + second_guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def first_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal first_guardrail_executed + timestamps["first_blocking_start"] = time.time() + await asyncio.sleep(0.3) + first_guardrail_executed = True + timestamps["first_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="first_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=False) + async def second_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal second_guardrail_executed + timestamps["second_blocking_start"] = time.time() + await asyncio.sleep(0.3) + second_guardrail_executed = True + timestamps["second_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="second_triggered", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="multi_blocking_agent", + instructions="Reply with 'hello'", + input_guardrails=[first_blocking_check, second_blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "test input") + + assert first_guardrail_executed is True + assert second_guardrail_executed is True + assert "first_blocking_start" in timestamps + assert "first_blocking_end" in timestamps + assert "second_blocking_start" in timestamps + assert "second_blocking_end" in timestamps + assert model.first_turn_args is None, ( + "Model should not have been called when guardrail triggered" + ) + + +@pytest.mark.asyncio +async def test_multiple_blocking_guardrails_one_triggers_streaming(): + timestamps = {} + first_guardrail_executed = False + second_guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def first_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal first_guardrail_executed + timestamps["first_blocking_start"] = time.time() + await asyncio.sleep(0.3) + first_guardrail_executed = True + timestamps["first_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="first_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=False) + async def second_blocking_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal second_guardrail_executed + timestamps["second_blocking_start"] = time.time() + await asyncio.sleep(0.3) + second_guardrail_executed = True + timestamps["second_blocking_end"] = time.time() + return GuardrailFunctionOutput( + output_info="second_triggered", + tripwire_triggered=True, + ) + + model = FakeModel() + agent = Agent( + name="multi_blocking_agent", + instructions="Reply with 'hello'", + input_guardrails=[first_blocking_check, second_blocking_check], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + result = Runner.run_streamed(agent, "test input") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert first_guardrail_executed is True + assert second_guardrail_executed is True + assert "first_blocking_start" in timestamps + assert "first_blocking_end" in timestamps + assert "second_blocking_start" in timestamps + assert "second_blocking_end" in timestamps + assert model.first_turn_args is None, ( + "Model should not have been called when guardrail triggered" + ) + + +@pytest.mark.asyncio +async def test_guardrail_via_agent_and_run_config_equivalent(): + agent_guardrail_executed = False + config_guardrail_executed = False + + @input_guardrail(run_in_parallel=False) + async def agent_level_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal agent_guardrail_executed + agent_guardrail_executed = True + return GuardrailFunctionOutput( + output_info="agent_level_passed", + tripwire_triggered=False, + ) + + @input_guardrail(run_in_parallel=False) + async def config_level_check( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal config_guardrail_executed + config_guardrail_executed = True + return GuardrailFunctionOutput( + output_info="config_level_passed", + tripwire_triggered=False, + ) + + model1 = FakeModel() + agent_with_guardrail = Agent( + name="test_agent", + instructions="Reply with 'hello'", + input_guardrails=[agent_level_check], + model=model1, + ) + model1.set_next_output([get_text_message("hello")]) + + model2 = FakeModel() + agent_without_guardrail = Agent( + name="test_agent", + instructions="Reply with 'hello'", + model=model2, + ) + model2.set_next_output([get_text_message("hello")]) + run_config = RunConfig(input_guardrails=[config_level_check]) + + result1 = await Runner.run(agent_with_guardrail, "test input") + result2 = await Runner.run(agent_without_guardrail, "test input", run_config=run_config) + + assert agent_guardrail_executed is True + assert config_guardrail_executed is True + assert len(result1.input_guardrail_results) == 1 + assert len(result2.input_guardrail_results) == 1 + assert result1.input_guardrail_results[0].output.output_info == "agent_level_passed" + assert result2.input_guardrail_results[0].output.output_info == "config_level_passed" + assert result1.final_output is not None + assert result2.final_output is not None + assert model1.first_turn_args is not None + assert model2.first_turn_args is not None + + +@pytest.mark.asyncio +async def test_blocking_guardrail_cancels_remaining_on_trigger(): + """ + Test that when one blocking guardrail triggers, remaining guardrails + are cancelled (non-streaming). + """ + fast_guardrail_executed = False + slow_guardrail_executed = False + slow_guardrail_cancelled = False + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def fast_guardrail_that_triggers( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal fast_guardrail_executed + timestamps["fast_start"] = time.time() + await asyncio.sleep(0.1) + fast_guardrail_executed = True + timestamps["fast_end"] = time.time() + return GuardrailFunctionOutput( + output_info="fast_triggered", + tripwire_triggered=True, + ) + + @input_guardrail(run_in_parallel=False) + async def slow_guardrail_that_should_be_cancelled( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal slow_guardrail_executed, slow_guardrail_cancelled + timestamps["slow_start"] = time.time() + try: + await asyncio.sleep(0.3) + slow_guardrail_executed = True + timestamps["slow_end"] = time.time() + return GuardrailFunctionOutput( + output_info="slow_completed", + tripwire_triggered=False, + ) + except asyncio.CancelledError: + slow_guardrail_cancelled = True + timestamps["slow_cancelled"] = time.time() + raise + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'hello'", + input_guardrails=[fast_guardrail_that_triggers, slow_guardrail_that_should_be_cancelled], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "test input") + + # Verify the fast guardrail executed + assert fast_guardrail_executed is True, "Fast guardrail should have executed" + + # Verify the slow guardrail was cancelled, not completed + assert slow_guardrail_cancelled is True, "Slow guardrail should have been cancelled" + assert slow_guardrail_executed is False, "Slow guardrail should NOT have completed execution" + + # Verify timing: cancellation happened shortly after fast guardrail triggered + assert "fast_end" in timestamps + assert "slow_cancelled" in timestamps + cancellation_delay = timestamps["slow_cancelled"] - timestamps["fast_end"] + assert cancellation_delay >= 0, ( + f"Slow guardrail should be cancelled after fast one completes, " + f"but was {cancellation_delay:.2f}s" + ) + assert cancellation_delay < 0.2, ( + f"Cancellation should happen before the slow guardrail completes, " + f"but took {cancellation_delay:.2f}s" + ) + + # Verify agent never started + assert model.first_turn_args is None, ( + "Model should not have been called when guardrail triggered" + ) + + +@pytest.mark.asyncio +async def test_blocking_guardrail_cancels_remaining_on_trigger_streaming(): + """ + Test that when one blocking guardrail triggers, remaining guardrails + are cancelled (streaming). + """ + fast_guardrail_executed = False + slow_guardrail_executed = False + slow_guardrail_cancelled = False + timestamps = {} + + @input_guardrail(run_in_parallel=False) + async def fast_guardrail_that_triggers( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal fast_guardrail_executed + timestamps["fast_start"] = time.time() + await asyncio.sleep(0.1) + fast_guardrail_executed = True + timestamps["fast_end"] = time.time() + return GuardrailFunctionOutput( + output_info="fast_triggered", + tripwire_triggered=True, + ) + + @input_guardrail(run_in_parallel=False) + async def slow_guardrail_that_should_be_cancelled( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + nonlocal slow_guardrail_executed, slow_guardrail_cancelled + timestamps["slow_start"] = time.time() + try: + await asyncio.sleep(0.3) + slow_guardrail_executed = True + timestamps["slow_end"] = time.time() + return GuardrailFunctionOutput( + output_info="slow_completed", + tripwire_triggered=False, + ) + except asyncio.CancelledError: + slow_guardrail_cancelled = True + timestamps["slow_cancelled"] = time.time() + raise + + model = FakeModel() + agent = Agent( + name="test_agent", + instructions="Reply with 'hello'", + input_guardrails=[fast_guardrail_that_triggers, slow_guardrail_that_should_be_cancelled], + model=model, + ) + model.set_next_output([get_text_message("hello")]) + + result = Runner.run_streamed(agent, "test input") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + # Verify the fast guardrail executed + assert fast_guardrail_executed is True, "Fast guardrail should have executed" + + # Verify the slow guardrail was cancelled, not completed + assert slow_guardrail_cancelled is True, "Slow guardrail should have been cancelled" + assert slow_guardrail_executed is False, "Slow guardrail should NOT have completed execution" + + # Verify timing: cancellation happened shortly after fast guardrail triggered + assert "fast_end" in timestamps + assert "slow_cancelled" in timestamps + cancellation_delay = timestamps["slow_cancelled"] - timestamps["fast_end"] + assert cancellation_delay >= 0, ( + f"Slow guardrail should be cancelled after fast one completes, " + f"but was {cancellation_delay:.2f}s" + ) + assert cancellation_delay < 0.2, ( + f"Cancellation should happen before the slow guardrail completes, " + f"but took {cancellation_delay:.2f}s" + ) + + # Verify agent never started + assert model.first_turn_args is None, ( + "Model should not have been called when guardrail triggered" + )