Skip to content

Commit 900241a

Browse files
committed
test: add cancellation tests for blocking guardrails
1 parent db58955 commit 900241a

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed

tests/test_guardrails.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,3 +1232,170 @@ async def config_level_check(
12321232
assert result2.final_output is not None
12331233
assert model1.first_turn_args is not None
12341234
assert model2.first_turn_args is not None
1235+
1236+
1237+
@pytest.mark.asyncio
1238+
async def test_blocking_guardrail_cancels_remaining_on_trigger():
1239+
"""
1240+
Test that when one blocking guardrail triggers, remaining guardrails
1241+
are cancelled (non-streaming).
1242+
"""
1243+
fast_guardrail_executed = False
1244+
slow_guardrail_executed = False
1245+
slow_guardrail_cancelled = False
1246+
timestamps = {}
1247+
1248+
@input_guardrail(run_in_parallel=False)
1249+
async def fast_guardrail_that_triggers(
1250+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
1251+
) -> GuardrailFunctionOutput:
1252+
nonlocal fast_guardrail_executed
1253+
timestamps["fast_start"] = time.time()
1254+
await asyncio.sleep(0.1)
1255+
fast_guardrail_executed = True
1256+
timestamps["fast_end"] = time.time()
1257+
return GuardrailFunctionOutput(
1258+
output_info="fast_triggered",
1259+
tripwire_triggered=True,
1260+
)
1261+
1262+
@input_guardrail(run_in_parallel=False)
1263+
async def slow_guardrail_that_should_be_cancelled(
1264+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
1265+
) -> GuardrailFunctionOutput:
1266+
nonlocal slow_guardrail_executed, slow_guardrail_cancelled
1267+
timestamps["slow_start"] = time.time()
1268+
try:
1269+
await asyncio.sleep(0.3)
1270+
slow_guardrail_executed = True
1271+
timestamps["slow_end"] = time.time()
1272+
return GuardrailFunctionOutput(
1273+
output_info="slow_completed",
1274+
tripwire_triggered=False,
1275+
)
1276+
except asyncio.CancelledError:
1277+
slow_guardrail_cancelled = True
1278+
timestamps["slow_cancelled"] = time.time()
1279+
raise
1280+
1281+
model = FakeModel()
1282+
agent = Agent(
1283+
name="test_agent",
1284+
instructions="Reply with 'hello'",
1285+
input_guardrails=[fast_guardrail_that_triggers, slow_guardrail_that_should_be_cancelled],
1286+
model=model,
1287+
)
1288+
model.set_next_output([get_text_message("hello")])
1289+
1290+
with pytest.raises(InputGuardrailTripwireTriggered):
1291+
await Runner.run(agent, "test input")
1292+
1293+
# Verify the fast guardrail executed
1294+
assert fast_guardrail_executed is True, "Fast guardrail should have executed"
1295+
1296+
# Verify the slow guardrail was cancelled, not completed
1297+
assert slow_guardrail_cancelled is True, "Slow guardrail should have been cancelled"
1298+
assert slow_guardrail_executed is False, "Slow guardrail should NOT have completed execution"
1299+
1300+
# Verify timing: cancellation happened shortly after fast guardrail triggered
1301+
assert "fast_end" in timestamps
1302+
assert "slow_cancelled" in timestamps
1303+
cancellation_delay = timestamps["slow_cancelled"] - timestamps["fast_end"]
1304+
assert cancellation_delay >= 0, (
1305+
f"Slow guardrail should be cancelled after fast one completes, "
1306+
f"but was {cancellation_delay:.2f}s"
1307+
)
1308+
assert cancellation_delay < 0.2, (
1309+
f"Cancellation should happen before the slow guardrail completes, "
1310+
f"but took {cancellation_delay:.2f}s"
1311+
)
1312+
1313+
# Verify agent never started
1314+
assert model.first_turn_args is None, (
1315+
"Model should not have been called when guardrail triggered"
1316+
)
1317+
1318+
1319+
@pytest.mark.asyncio
1320+
async def test_blocking_guardrail_cancels_remaining_on_trigger_streaming():
1321+
"""
1322+
Test that when one blocking guardrail triggers, remaining guardrails
1323+
are cancelled (streaming).
1324+
"""
1325+
fast_guardrail_executed = False
1326+
slow_guardrail_executed = False
1327+
slow_guardrail_cancelled = False
1328+
timestamps = {}
1329+
1330+
@input_guardrail(run_in_parallel=False)
1331+
async def fast_guardrail_that_triggers(
1332+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
1333+
) -> GuardrailFunctionOutput:
1334+
nonlocal fast_guardrail_executed
1335+
timestamps["fast_start"] = time.time()
1336+
await asyncio.sleep(0.1)
1337+
fast_guardrail_executed = True
1338+
timestamps["fast_end"] = time.time()
1339+
return GuardrailFunctionOutput(
1340+
output_info="fast_triggered",
1341+
tripwire_triggered=True,
1342+
)
1343+
1344+
@input_guardrail(run_in_parallel=False)
1345+
async def slow_guardrail_that_should_be_cancelled(
1346+
ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
1347+
) -> GuardrailFunctionOutput:
1348+
nonlocal slow_guardrail_executed, slow_guardrail_cancelled
1349+
timestamps["slow_start"] = time.time()
1350+
try:
1351+
await asyncio.sleep(0.3)
1352+
slow_guardrail_executed = True
1353+
timestamps["slow_end"] = time.time()
1354+
return GuardrailFunctionOutput(
1355+
output_info="slow_completed",
1356+
tripwire_triggered=False,
1357+
)
1358+
except asyncio.CancelledError:
1359+
slow_guardrail_cancelled = True
1360+
timestamps["slow_cancelled"] = time.time()
1361+
raise
1362+
1363+
model = FakeModel()
1364+
agent = Agent(
1365+
name="test_agent",
1366+
instructions="Reply with 'hello'",
1367+
input_guardrails=[fast_guardrail_that_triggers, slow_guardrail_that_should_be_cancelled],
1368+
model=model,
1369+
)
1370+
model.set_next_output([get_text_message("hello")])
1371+
1372+
result = Runner.run_streamed(agent, "test input")
1373+
1374+
with pytest.raises(InputGuardrailTripwireTriggered):
1375+
async for _event in result.stream_events():
1376+
pass
1377+
1378+
# Verify the fast guardrail executed
1379+
assert fast_guardrail_executed is True, "Fast guardrail should have executed"
1380+
1381+
# Verify the slow guardrail was cancelled, not completed
1382+
assert slow_guardrail_cancelled is True, "Slow guardrail should have been cancelled"
1383+
assert slow_guardrail_executed is False, "Slow guardrail should NOT have completed execution"
1384+
1385+
# Verify timing: cancellation happened shortly after fast guardrail triggered
1386+
assert "fast_end" in timestamps
1387+
assert "slow_cancelled" in timestamps
1388+
cancellation_delay = timestamps["slow_cancelled"] - timestamps["fast_end"]
1389+
assert cancellation_delay >= 0, (
1390+
f"Slow guardrail should be cancelled after fast one completes, "
1391+
f"but was {cancellation_delay:.2f}s"
1392+
)
1393+
assert cancellation_delay < 0.2, (
1394+
f"Cancellation should happen before the slow guardrail completes, "
1395+
f"but took {cancellation_delay:.2f}s"
1396+
)
1397+
1398+
# Verify agent never started
1399+
assert model.first_turn_args is None, (
1400+
"Model should not have been called when guardrail triggered"
1401+
)

0 commit comments

Comments
 (0)