diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 485ecbf26..449184b3d 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -1084,6 +1084,7 @@ async def test_allowed_tools_function_tool_filtered_by_type_and_name( ) prepared, choice = await resolve_tool_choice(tools, allowed, "token") assert prepared is not None and len(prepared) == 1 + assert isinstance(prepared[0], InputToolFunction) assert prepared[0].name == "keep_fn" assert choice == ToolChoiceMode.required @@ -1276,6 +1277,7 @@ def test_mcp_type_and_server_label_specific(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].server_label == "keep" def test_function_type_and_name(self) -> None: @@ -1293,6 +1295,7 @@ def test_function_type_and_name(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolFunction) assert out[0].name == "fn_b" def test_web_search_type_literal_must_match(self) -> None: @@ -1361,6 +1364,7 @@ def test_mcp_name_grouped_by_server_narrows_allowed_tools(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].allowed_tools == ["alpha"] def test_mcp_allowed_tools_none_projects_to_entry_names(self) -> None: @@ -1381,6 +1385,7 @@ def test_mcp_allowed_tools_none_projects_to_entry_names(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].allowed_tools == ["gamma"] def test_mcp_server_without_name_in_allowlist_skips_projection(self) -> None: @@ -1404,6 +1409,7 @@ def test_mcp_server_without_name_in_allowlist_skips_projection(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].allowed_tools == ["a", "b"] def test_mcp_allowed_tools_filter_tool_names_none(self) -> None: @@ -1424,6 +1430,7 @@ def test_mcp_allowed_tools_filter_tool_names_none(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].allowed_tools == ["z"] def test_mcp_allowed_tools_filter_intersects_with_grouped_names(self) -> None: @@ -1444,6 +1451,7 @@ def test_mcp_allowed_tools_filter_intersects_with_grouped_names(self) -> None: ) assert out is not None assert len(out) == 1 + assert isinstance(out[0], InputToolMCP) assert out[0].allowed_tools == ["alpha"] def test_mcp_name_not_permitted_drops_tool(self) -> None: @@ -1501,6 +1509,7 @@ async def test_prepare_tools_fetch_vector_stores( result = await prepare_tools(mock_client, None, False, "token") assert result is not None assert len(result) == 1 + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["vs1", "vs2"] @pytest.mark.asyncio @@ -1583,6 +1592,10 @@ def _make_byok_rag(rag_id: str, vector_db_id: str) -> ByokRag: rag_id=rag_id, vector_db_id=vector_db_id, db_path="tests/configuration/rag.txt", + rag_type="rag", + embedding_model="model", + embedding_dimension=768, + score_multiplier=1.0, ) def test_translates_customer_facing_ids_to_internal(self) -> None: @@ -1671,6 +1684,7 @@ async def test_passes_through_unknown_ids_in_prepare_tools( result = await prepare_tools(mock_client, ["raw-internal-id"], False, "token") assert result is not None + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["raw-internal-id"] @pytest.mark.asyncio @@ -1702,6 +1716,7 @@ async def test_does_not_translate_when_ids_fetched_from_llama_stack( result = await prepare_tools(mock_client, None, False, "token") assert result is not None # The IDs from llama-stack should be used as-is (no BYOK translation on None path) + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["vs-internal"] @@ -1750,6 +1765,7 @@ async def test_rag_tool_config_ids_are_translated( result = await prepare_tools(mock_client, None, False, "token") assert result is not None + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["vs-001"] mock_client.vector_stores.list.assert_not_called() @@ -1790,6 +1806,7 @@ async def test_per_request_ids_override_rag_tool_config( result = await prepare_tools(mock_client, ["request-id-1"], False, "token") assert result is not None + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["request-id-1"] mock_client.vector_stores.list.assert_not_called() @@ -1815,6 +1832,7 @@ async def test_all_registered_dbs_used_when_neither_tool_nor_inline_configured( result = await prepare_tools(mock_client, None, False, "token") assert result is not None + assert isinstance(result[0], InputToolFileSearch) assert result[0].vector_store_ids == ["vs-registered"] mock_client.vector_stores.list.assert_called_once()