|
12 | 12 | from functools import wraps |
13 | 13 | from pathlib import Path |
14 | 14 | from typing import cast |
15 | | -from unittest.mock import AsyncMock, MagicMock, patch |
| 15 | +from unittest.mock import AsyncMock, patch |
16 | 16 | from uuid import uuid4 |
17 | 17 |
|
18 | 18 | import ldp.agent |
|
21 | 21 | Environment, |
22 | 22 | Tool, |
23 | 23 | ToolRequestMessage, |
| 24 | + ToolResponseMessage, |
24 | 25 | ToolsAdapter, |
25 | 26 | ToolSelector, |
26 | 27 | ) |
@@ -470,26 +471,27 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> |
470 | 471 | agent_test_settings.agent.timeout = 0.05 # Give time for Environment.reset() |
471 | 472 | agent_test_settings.llm = "gpt-4o-mini" |
472 | 473 | agent_test_settings.agent.tool_names = {"gen_answer", "complete"} |
473 | | - docs = Docs() |
| 474 | + orig_exec_tool_calls = PaperQAEnvironment.exec_tool_calls |
| 475 | + tool_responses: list[list[ToolResponseMessage]] = [] |
474 | 476 |
|
475 | | - async def custom_aget_evidence(*_, **kwargs) -> PQASession: # noqa: RUF029 |
476 | | - return kwargs["query"] |
| 477 | + async def spy_exec_tool_calls(*args, **kwargs) -> list[ToolResponseMessage]: |
| 478 | + responses = await orig_exec_tool_calls(*args, **kwargs) |
| 479 | + tool_responses.append(responses) |
| 480 | + return responses |
477 | 481 |
|
478 | | - with ( |
479 | | - patch.object(docs, "docs", {"stub_key": MagicMock(spec_set=Doc)}), |
480 | | - patch.multiple( |
481 | | - Docs, clear_docs=MagicMock(), aget_evidence=custom_aget_evidence |
482 | | - ), |
483 | | - ): |
| 482 | + with patch.object(PaperQAEnvironment, "exec_tool_calls", spy_exec_tool_calls): |
484 | 483 | response = await agent_query( |
485 | 484 | query="Are COVID-19 vaccines effective?", |
486 | 485 | settings=agent_test_settings, |
487 | | - docs=docs, |
488 | 486 | agent_type=agent_type, |
489 | 487 | ) |
490 | 488 | # Ensure that GenerateAnswerTool was called in truncation's failover |
491 | 489 | assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout" |
492 | 490 | assert CANNOT_ANSWER_PHRASE in response.session.answer |
| 491 | + (last_response,) = tool_responses[-1] |
| 492 | + assert ( |
| 493 | + "no papers" in last_response.content |
| 494 | + ), "Expecting agent to have been shown specifics on the failure" |
493 | 495 |
|
494 | 496 |
|
495 | 497 | @pytest.mark.flaky(reruns=5, only_rerun=["AssertionError"]) |
|
0 commit comments