diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 20b207d5c5c..6482b84aa0a 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.Consumer; @@ -52,6 +53,7 @@ import org.springframework.ai.content.Media; import org.springframework.ai.converter.BeanOutputConverter; import org.springframework.ai.converter.StructuredOutputConverter; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.support.ToolCallbacks; import org.springframework.ai.template.TemplateRenderer; import org.springframework.ai.template.st.StTemplateRenderer; @@ -521,8 +523,33 @@ private ChatClientResponse doGetObservableChatClientResponse(ChatClientRequest c @Nullable private static String getContentFromChatResponse(@Nullable ChatResponse chatResponse) { - return Optional.ofNullable(chatResponse) - .map(ChatResponse::getResult) + if (chatResponse == null) { + return null; + } + var results = chatResponse.getResults(); + if (results == null || results.isEmpty()) { + return null; + } + if (results.size() == 1) { + return Optional.ofNullable(results.get(0)) + .map(Generation::getOutput) + .map(AbstractMessage::getText) + .orElse(null); + } + boolean allReturnDirect = results.stream().allMatch(g -> { + var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null; + return finish != null && finish.equalsIgnoreCase(ToolExecutionResult.FINISH_REASON); + }); + if (allReturnDirect) { + return results.stream() + .map(Generation::getOutput) + .map(AbstractMessage::getText) + .filter(Objects::nonNull) + .filter(StringUtils::hasText) + .reduce((a, b) -> a + "\n" + b) + .orElse(null); + } + return Optional.ofNullable(results.get(0)) .map(Generation::getOutput) .map(AbstractMessage::getText) .orElse(null); @@ -594,10 +621,35 @@ public Flux content() { // @formatter:off return doGetObservableFluxChatResponse(this.request) .mapNotNull(ChatClientResponse::chatResponse) - .map(r -> Optional.ofNullable(r.getResult()) + .map(r -> { + var results = r.getResults(); + if (results == null || results.isEmpty()) { + return ""; + } + if (results.size() == 1) { + return Optional.ofNullable(results.get(0)) + .map(Generation::getOutput) + .map(AbstractMessage::getText) + .orElse(""); + } + boolean allReturnDirect = results.stream().allMatch(g -> { + var finish = g.getMetadata() != null ? g.getMetadata().getFinishReason() : null; + return finish != null && finish.equalsIgnoreCase(org.springframework.ai.model.tool.ToolExecutionResult.FINISH_REASON); + }); + if (allReturnDirect) { + return results.stream() + .map(Generation::getOutput) + .map(AbstractMessage::getText) + .filter(java.util.Objects::nonNull) + .filter(StringUtils::hasText) + .reduce((a, b) -> a + "\n" + b) + .orElse(""); + } + return Optional.ofNullable(results.get(0)) .map(Generation::getOutput) .map(AbstractMessage::getText) - .orElse("")) + .orElse(""); + }) .filter(StringUtils::hasLength); // @formatter:on } diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientReturnDirectAggregationTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientReturnDirectAggregationTests.java new file mode 100644 index 00000000000..5f61ab61b03 --- /dev/null +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientReturnDirectAggregationTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.client; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolExecutionResult; + +import static org.assertj.core.api.Assertions.assertThat; + +/* + * @author: Kuntal Maity + */ +class DefaultChatClientReturnDirectAggregationTests { + + private static Generation generation(String text, String finishReason) { + var metadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); + return new Generation(new AssistantMessage(text), metadata); + } + + @Test + void aggregatesMultipleReturnDirectGenerationsInContent() { + var chatResponse = new ChatResponse(List.of(generation("DATE=2025-10-18", ToolExecutionResult.FINISH_REASON), + generation("TIME=12:34:56.789", ToolExecutionResult.FINISH_REASON))); + + ChatModel stub = new ChatModel() { + @Override + public ChatResponse call(Prompt prompt) { + return chatResponse; + } + }; + + var client = ChatClient.builder(stub).build(); + String content = client.prompt("now").call().content(); + + assertThat(content).isEqualTo("DATE=2025-10-18\nTIME=12:34:56.789"); + } + + @Test + void returnsFirstWhenNotAllReturnDirect() { + var chatResponse = new ChatResponse( + List.of(generation("FIRST", ToolExecutionResult.FINISH_REASON), generation("SECOND", "stop"))); + + ChatModel stub = new ChatModel() { + @Override + public ChatResponse call(Prompt prompt) { + return chatResponse; + } + }; + + var client = ChatClient.builder(stub).build(); + String content = client.prompt("now").call().content(); + + assertThat(content).isEqualTo("FIRST"); + } + +}