From f348b3a774a20e00a2013ea191bb11a5d5f6be57 Mon Sep 17 00:00:00 2001 From: Oleksandr Klymenko Date: Mon, 6 Oct 2025 23:02:08 +0200 Subject: [PATCH 1/3] test: Add comprehensive test coverage for DefaultToolCallingManager Co-authored-by: Oleksandr Klymenko Signed-off-by: Oleksandr Klymenko --- .../tool/DefaultToolCallingManagerTest.java | 266 ++++++++++++++++++ 1 file changed, 266 insertions(+) diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index ce775b20cd5..0ea3af84fee 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -155,4 +155,270 @@ public String call(String toolInput) { assertThatNoException().isThrownBy(() -> managerWithCallback.executeToolCalls(prompt, chatResponse)); } + @Test + void shouldHandleMultipleToolCallsInSingleResponse() { + // Create mock tool callbacks + ToolCallback toolCallback1 = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("tool1") + .description("First tool") + .inputSchema("{\"type\": \"object\", \"properties\": {\"param\": {\"type\": \"string\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"tool1_success\"}"; + } + }; + + ToolCallback toolCallback2 = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("tool2") + .description("Second tool") + .inputSchema("{\"type\": \"object\", \"properties\": {\"value\": {\"type\": \"number\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"tool2_success\"}"; + } + }; + + // Create multiple ToolCalls + AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "tool1", + "{\"param\": \"test\"}"); + AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "tool2", + "{\"value\": 42}"); + + // Create ChatResponse with multiple tool calls + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall1, toolCall2)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test multiple tools"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> { + if ("tool1".equals(toolName)) + return toolCallback1; + if ("tool2".equals(toolName)) + return toolCallback2; + return null; + }) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallWithComplexJsonArguments() { + ToolCallback complexToolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("complexTool") + .description("A tool with complex JSON input") + .inputSchema("{\"type\": \"object\", \"properties\": {\"nested\": {\"type\": \"object\"}}}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + assertThat(toolInput).contains("nested"); + assertThat(toolInput).contains("array"); + return "{\"result\": \"processed\"}"; + } + }; + + String complexJson = "{\"nested\": {\"level1\": {\"level2\": \"value\"}}, \"array\": [1, 2, 3]}"; + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "complexTool", complexJson); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test complex json"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "complexTool".equals(toolName) ? complexToolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallWithMalformedJson() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("testTool") + .description("Test tool") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + // Should still receive some input even if malformed + assertThat(toolInput).isNotNull(); + return "{\"result\": \"handled\"}"; + } + }; + + // Malformed JSON as tool arguments + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "testTool", + "{invalid json}"); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test malformed json"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "testTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleToolCallReturningNull() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("nullReturningTool") + .description("Tool that returns null") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return null; // Return null + } + }; + + AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "function", "nullReturningTool", "{}"); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall)) + .build(); + Generation generation = new Generation(assistantMessage); + ChatResponse chatResponse = new ChatResponse(List.of(generation)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test null return"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "nullReturningTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + + @Test + void shouldHandleMultipleGenerationsWithToolCalls() { + ToolCallback toolCallback = new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return DefaultToolDefinition.builder() + .name("multiGenTool") + .description("Tool for multiple generations") + .inputSchema("{}") + .build(); + } + + @Override + public ToolMetadata getToolMetadata() { + return ToolMetadata.builder().build(); + } + + @Override + public String call(String toolInput) { + return "{\"result\": \"success\"}"; + } + }; + + // Create multiple generations with tool calls + AssistantMessage.ToolCall toolCall1 = new AssistantMessage.ToolCall("1", "function", "multiGenTool", "{}"); + AssistantMessage.ToolCall toolCall2 = new AssistantMessage.ToolCall("2", "function", "multiGenTool", "{}"); + + AssistantMessage assistantMessage1 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall1)) + .build(); + + AssistantMessage assistantMessage2 = AssistantMessage.builder() + .content("") + .properties(Map.of()) + .toolCalls(List.of(toolCall2)) + .build(); + + Generation generation1 = new Generation(assistantMessage1); + Generation generation2 = new Generation(assistantMessage2); + + ChatResponse chatResponse = new ChatResponse(List.of(generation1, generation2)); + + Prompt prompt = new Prompt(List.of(new UserMessage("test multiple generations"))); + + DefaultToolCallingManager manager = DefaultToolCallingManager.builder() + .observationRegistry(ObservationRegistry.NOOP) + .toolCallbackResolver(toolName -> "multiGenTool".equals(toolName) ? toolCallback : null) + .build(); + + assertThatNoException().isThrownBy(() -> manager.executeToolCalls(prompt, chatResponse)); + } + } From 41f5f77751ccdb1ee59660dbbae3efa042bf3409 Mon Sep 17 00:00:00 2001 From: Oleksandr Klymenko Date: Mon, 6 Oct 2025 23:12:18 +0200 Subject: [PATCH 2/3] style: fix formatting Co-authored-by: Oleksandr Klymenko Signed-off-by: Oleksandr Klymenko --- .../ai/model/tool/DefaultToolCallingManagerTest.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index 0ea3af84fee..81818c034ec 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -220,10 +220,11 @@ public String call(String toolInput) { DefaultToolCallingManager manager = DefaultToolCallingManager.builder() .observationRegistry(ObservationRegistry.NOOP) .toolCallbackResolver(toolName -> { - if ("tool1".equals(toolName)) + if ("tool1".equals(toolName)) { return toolCallback1; - if ("tool2".equals(toolName)) + } if ("tool2".equals(toolName)) { return toolCallback2; + } return null; }) .build(); From 2b8eb344d1bb13e4a44094904b826c807d91e5f4 Mon Sep 17 00:00:00 2001 From: Oleksandr Klymenko Date: Mon, 6 Oct 2025 23:15:56 +0200 Subject: [PATCH 3/3] style: fix formatting Co-authored-by: Oleksandr Klymenko Signed-off-by: Oleksandr Klymenko --- .../ai/model/tool/DefaultToolCallingManagerTest.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java index 81818c034ec..bd60639c323 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingManagerTest.java @@ -222,7 +222,8 @@ public String call(String toolInput) { .toolCallbackResolver(toolName -> { if ("tool1".equals(toolName)) { return toolCallback1; - } if ("tool2".equals(toolName)) { + } + if ("tool2".equals(toolName)) { return toolCallback2; } return null;