Skip to content

Commit 18657c5

Browse files
committed
GH-5007: Fix handling when response contains both text and function calls
Signed-off-by: Nathan Grand <nathangrand@quantexa.com>
1 parent 21db782 commit 18657c5

File tree

3 files changed

+73
-37
lines changed

3 files changed

+73
-37
lines changed

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashMap;
2323
import java.util.List;
2424
import java.util.Map;
25+
import java.util.stream.Collectors;
2526

2627
import com.fasterxml.jackson.annotation.JsonInclude;
2728
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -647,45 +648,30 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
647648
.finishReason(candidateFinishReason.toString())
648649
.build();
649650

650-
boolean isFunctionCall = candidate.content().isPresent() && candidate.content().get().parts().isPresent()
651-
&& candidate.content().get().parts().get().stream().allMatch(part -> part.functionCall().isPresent());
651+
List<Part> parts = candidate.content().get().parts().orElse(List.of());
652652

653-
if (isFunctionCall) {
654-
List<AssistantMessage.ToolCall> assistantToolCalls = candidate.content()
655-
.get()
656-
.parts()
657-
.orElse(List.of())
658-
.stream()
659-
.filter(part -> part.functionCall().isPresent())
660-
.map(part -> {
661-
FunctionCall functionCall = part.functionCall().get();
662-
var functionName = functionCall.name().orElse("");
663-
String functionArguments = mapToJson(functionCall.args().orElse(Map.of()));
664-
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
665-
})
666-
.toList();
653+
List<AssistantMessage.ToolCall> assistantToolCalls = parts.stream()
654+
.filter(part -> part.functionCall().isPresent())
655+
.map(part -> {
656+
FunctionCall functionCall = part.functionCall().get();
657+
var functionName = functionCall.name().orElse("");
658+
String functionArguments = mapToJson(functionCall.args().orElse(Map.of()));
659+
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
660+
})
661+
.toList();
667662

668-
AssistantMessage assistantMessage = AssistantMessage.builder()
669-
.content("")
670-
.properties(messageMetadata)
671-
.toolCalls(assistantToolCalls)
672-
.build();
663+
String text = parts.stream()
664+
.filter(part -> part.text().isPresent() && !part.text().get().isEmpty())
665+
.map(part -> part.text().get())
666+
.collect(Collectors.joining(" "));
673667

674-
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
675-
}
676-
else {
677-
return candidate.content()
678-
.get()
679-
.parts()
680-
.orElse(List.of())
681-
.stream()
682-
.map(part -> AssistantMessage.builder()
683-
.content(part.text().orElse(""))
684-
.properties(messageMetadata)
685-
.build())
686-
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata))
687-
.toList();
688-
}
668+
AssistantMessage assistantMessage = AssistantMessage.builder()
669+
.content(text)
670+
.properties(messageMetadata)
671+
.toolCalls(assistantToolCalls)
672+
.build();
673+
674+
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
689675
}
690676

691677
private ChatResponseMetadata toChatResponseMetadata(Usage usage, String modelVersion) {

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatModelIT.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,41 @@ void jsonTextToolCallingTest() {
388388
assertThat(response).contains("2025-05-08T10:10:10+02:00");
389389
}
390390

391+
/**
392+
* See https://github.com/spring-projects/spring-ai/pull/4599
393+
*/
394+
@Test
395+
void testMixedPartsMessages() {
396+
397+
ToolCallingManager toolCallingManager = ToolCallingManager.builder()
398+
.observationRegistry(ObservationRegistry.NOOP)
399+
.build();
400+
401+
GoogleGenAiChatModel chatModelWithTools = GoogleGenAiChatModel.builder()
402+
.genAiClient(genAiClient())
403+
.toolCallingManager(toolCallingManager)
404+
.defaultOptions(GoogleGenAiChatOptions.builder()
405+
.model(GoogleGenAiChatModel.ChatModel.GEMINI_2_5_FLASH)
406+
.temperature(0.0)
407+
.build())
408+
.build();
409+
410+
ChatClient chatClient = ChatClient.builder(chatModelWithTools).build();
411+
412+
// Create a prompt that will encourage gemini to explain why it is calling tools
413+
// as it does.
414+
AlarmTools alarmTools = new AlarmTools();
415+
String response = chatClient.prompt()
416+
.tools(new CurrentTimeTools(), alarmTools)
417+
.system("You MUST include reasoning when you issue tool calls.")
418+
.user("Set an alarm for an hour from now, and tell me what time that was for.")
419+
.call()
420+
.content();
421+
422+
assertThat(response).isEqualTo("I have set an alarm for 11:10 AM.");
423+
assertThat(alarmTools.getAlarm()).isEqualTo("2025-05-08T11:10:10+02:00");
424+
}
425+
391426
@Test
392427
void testThinkingBudgetGeminiProAutomaticDecisionByModel() {
393428
GoogleGenAiChatModel chatModelWithThinkingBudget = GoogleGenAiChatModel.builder()
@@ -516,6 +551,21 @@ String getCurrentDateTime() {
516551

517552
}
518553

554+
public static class AlarmTools {
555+
556+
private String alarm;
557+
558+
@Tool(description = "Set a user alarm for the given time, provided in ISO-8601 format")
559+
void setAlarm(String time) {
560+
this.alarm = time;
561+
}
562+
563+
public String getAlarm() {
564+
return this.alarm;
565+
}
566+
567+
}
568+
519569
record ActorsFilmsRecord(String actor, List<String> movies) {
520570

521571
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/cache/GoogleGenAiCachedContentServiceTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ void testUpdateCachedContent() {
152152
assertThat(updated.getName()).isEqualTo(name);
153153
assertThat(updated.getTtl()).isEqualTo(newTtl);
154154
assertThat(updated.getUpdateTime()).isNotNull();
155-
assertThat(updated.getUpdateTime()).isAfter(created.getCreateTime());
155+
assertThat(updated.getUpdateTime()).isAfterOrEqualTo(created.getCreateTime());
156156
}
157157

158158
@Test

0 commit comments

Comments
 (0)