-
Notifications
You must be signed in to change notification settings - Fork 2.9k
feat: add run_in_parallel parameter to input guardrails #1986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d28f527
d4bc971
0050939
721e9ce
e35aa7c
f2c34eb
339895f
db58955
900241a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -601,11 +601,31 @@ async def run( | |
| ) | ||
|
|
||
| if current_turn == 1: | ||
| # Separate guardrails based on execution mode. | ||
| all_input_guardrails = starting_agent.input_guardrails + ( | ||
| run_config.input_guardrails or [] | ||
| ) | ||
| sequential_guardrails = [ | ||
| g for g in all_input_guardrails if not g.run_in_parallel | ||
| ] | ||
| parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] | ||
|
|
||
| # Run blocking guardrails first, before agent starts. | ||
| # (will raise exception if tripwire triggered). | ||
| sequential_results = [] | ||
| if sequential_guardrails: | ||
| sequential_results = await self._run_input_guardrails( | ||
| starting_agent, | ||
| sequential_guardrails, | ||
| _copy_str_or_list(prepared_input), | ||
| context_wrapper, | ||
| ) | ||
|
|
||
| # Run parallel guardrails + agent together. | ||
| input_guardrail_results, turn_result = await asyncio.gather( | ||
| self._run_input_guardrails( | ||
| starting_agent, | ||
| starting_agent.input_guardrails | ||
| + (run_config.input_guardrails or []), | ||
| parallel_guardrails, | ||
| _copy_str_or_list(prepared_input), | ||
| context_wrapper, | ||
| ), | ||
|
|
@@ -622,6 +642,9 @@ async def run( | |
| server_conversation_tracker=server_conversation_tracker, | ||
| ), | ||
| ) | ||
|
|
||
| # Combine sequential and parallel results. | ||
| input_guardrail_results = sequential_results + input_guardrail_results | ||
| else: | ||
| turn_result = await self._run_single_turn( | ||
| agent=current_agent, | ||
|
|
@@ -941,6 +964,11 @@ async def _run_input_guardrails_with_queue( | |
| for done in asyncio.as_completed(guardrail_tasks): | ||
| result = await done | ||
| if result.output.tripwire_triggered: | ||
| # Cancel all remaining guardrail tasks if a tripwire is triggered. | ||
| for t in guardrail_tasks: | ||
| t.cancel() | ||
| # Wait for cancellations to propagate by awaiting the cancelled tasks. | ||
| await asyncio.gather(*guardrail_tasks, return_exceptions=True) | ||
| _error_tracing.attach_error_to_span( | ||
| parent_span, | ||
| SpanError( | ||
|
|
@@ -951,14 +979,19 @@ async def _run_input_guardrails_with_queue( | |
| }, | ||
| ), | ||
| ) | ||
| queue.put_nowait(result) | ||
| guardrail_results.append(result) | ||
| break | ||
| queue.put_nowait(result) | ||
| guardrail_results.append(result) | ||
| except Exception: | ||
| for t in guardrail_tasks: | ||
| t.cancel() | ||
| raise | ||
|
|
||
| streamed_result.input_guardrail_results = guardrail_results | ||
| streamed_result.input_guardrail_results = ( | ||
| streamed_result.input_guardrail_results + guardrail_results | ||
| ) | ||
|
|
||
| @classmethod | ||
| async def _start_streaming( | ||
|
Comment on lines
989
to
997
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Blocking input guardrails are cancelled in the non‑streaming path ( Useful? React with 👍 / 👎.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
@@ -1050,11 +1083,36 @@ async def _start_streaming( | |
| break | ||
|
|
||
| if current_turn == 1: | ||
| # Run the input guardrails in the background and put the results on the queue | ||
| # Separate guardrails based on execution mode. | ||
| all_input_guardrails = starting_agent.input_guardrails + ( | ||
| run_config.input_guardrails or [] | ||
| ) | ||
| sequential_guardrails = [ | ||
| g for g in all_input_guardrails if not g.run_in_parallel | ||
| ] | ||
| parallel_guardrails = [g for g in all_input_guardrails if g.run_in_parallel] | ||
|
|
||
| # Run sequential guardrails first. | ||
| if sequential_guardrails: | ||
| await cls._run_input_guardrails_with_queue( | ||
| starting_agent, | ||
| sequential_guardrails, | ||
| ItemHelpers.input_to_new_input_list(prepared_input), | ||
| context_wrapper, | ||
| streamed_result, | ||
| current_span, | ||
| ) | ||
| # Check if any blocking guardrail triggered and raise before starting agent. | ||
| for result in streamed_result.input_guardrail_results: | ||
| if result.output.tripwire_triggered: | ||
| streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) | ||
| raise InputGuardrailTripwireTriggered(result) | ||
|
|
||
| # Run parallel guardrails in background. | ||
| streamed_result._input_guardrails_task = asyncio.create_task( | ||
| cls._run_input_guardrails_with_queue( | ||
| starting_agent, | ||
| starting_agent.input_guardrails + (run_config.input_guardrails or []), | ||
| parallel_guardrails, | ||
| ItemHelpers.input_to_new_input_list(prepared_input), | ||
| context_wrapper, | ||
| streamed_result, | ||
|
|
@@ -1619,6 +1677,8 @@ async def _run_input_guardrails( | |
| # Cancel all guardrail tasks if a tripwire is triggered. | ||
| for t in guardrail_tasks: | ||
| t.cancel() | ||
| # Wait for cancellations to propagate by awaiting the cancelled tasks. | ||
| await asyncio.gather(*guardrail_tasks, return_exceptions=True) | ||
| _error_tracing.attach_error_to_current_span( | ||
| SpanError( | ||
| message="Guardrail tripwire triggered", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class-level docstring for
InputGuardrailstates 'Input guardrails are checks that run in parallel to the agent's execution' but this is no longer accurate since guardrails can now run sequentially. The docstring should be updated to reflect that guardrails can run either in parallel or sequentially based on therun_in_parallelparameter.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in 1ad513a