From 2df44de6f83bd17812f87ad5e5c0bf881ce99e74 Mon Sep 17 00:00:00 2001 From: OwenDavisBC Date: Mon, 20 Oct 2025 11:01:19 -0600 Subject: [PATCH] fix: Gemini thoughts not correctly accumulated when streaming enabled --- .../java/com/google/adk/models/Gemini.java | 29 ++++- .../com/google/adk/models/GeminiUtil.java | 33 ++--- .../com/google/adk/models/GeminiTest.java | 9 +- .../com/google/adk/models/GeminiUtilTest.java | 120 ++++-------------- 4 files changed, 62 insertions(+), 129 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/Gemini.java b/core/src/main/java/com/google/adk/models/Gemini.java index daa33971..4b21a173 100644 --- a/core/src/main/java/com/google/adk/models/Gemini.java +++ b/core/src/main/java/com/google/adk/models/Gemini.java @@ -35,6 +35,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -236,6 +237,7 @@ public Flowable generateContent(LlmRequest llmRequest, boolean stre static Flowable processRawResponses(Flowable rawResponses) { final StringBuilder accumulatedText = new StringBuilder(); + final StringBuilder accumulatedThoughtText = new StringBuilder(); // Array to bypass final local variable reassignment in lambda. final GenerateContentResponse[] lastRawResponseHolder = {null}; return rawResponses @@ -246,15 +248,26 @@ static Flowable processRawResponses(Flowable responsesToEmit = new ArrayList<>(); LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse); - String currentTextChunk = - GeminiUtil.getTextFromLlmResponse(currentProcessedLlmResponse); + Optional part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse); + String currentTextChunk = part.flatMap(Part::text).orElse(""); if (!currentTextChunk.isEmpty()) { - accumulatedText.append(currentTextChunk); + if (part.get().thought().orElse(false)) { + accumulatedThoughtText.append(currentTextChunk); + } else { + accumulatedText.append(currentTextChunk); + } LlmResponse partialResponse = currentProcessedLlmResponse.toBuilder().partial(true).build(); responsesToEmit.add(partialResponse); } else { + if (accumulatedThoughtText.length() > 0 + && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { + LlmResponse aggregatedThoughtResponse = + thinkingResponseFromText(accumulatedThoughtText.toString()); + responsesToEmit.add(aggregatedThoughtResponse); + accumulatedThoughtText.setLength(0); + } if (accumulatedText.length() > 0 && GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) { LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString()); @@ -296,6 +309,16 @@ private static LlmResponse responseFromText(String accumulatedText) { .build(); } + private static LlmResponse thinkingResponseFromText(String accumulatedThoughtText) { + return LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(Part.fromText(accumulatedThoughtText).toBuilder().thought(true).build()) + .build()) + .build(); + } + @Override public BaseLlmConnection connect(LlmRequest llmRequest) { if (!apiClient.vertexAI()) { diff --git a/core/src/main/java/com/google/adk/models/GeminiUtil.java b/core/src/main/java/com/google/adk/models/GeminiUtil.java index 501b279b..60317f75 100644 --- a/core/src/main/java/com/google/adk/models/GeminiUtil.java +++ b/core/src/main/java/com/google/adk/models/GeminiUtil.java @@ -26,6 +26,7 @@ import com.google.genai.types.FileData; import com.google.genai.types.Part; import java.util.List; +import java.util.Optional; import java.util.stream.Stream; /** Request / Response utilities for {@link Gemini}. */ @@ -41,7 +42,7 @@ private GeminiUtil() {} * Prepares an {@link LlmRequest} for the GenerateContent API. * *

This method can optionally sanitize the request and ensures that the last content part is - * from the user to prompt a model response. It also strips out any parts marked as "thoughts". + * from the user to prompt a model response. * * @param llmRequest The original {@link LlmRequest}. * @param sanitize Whether to sanitize the request to be compatible with the Gemini API backend. @@ -53,8 +54,7 @@ public static LlmRequest prepareGenenerateContentRequest( llmRequest = sanitizeRequestForGeminiApi(llmRequest); } List contents = ensureModelResponse(llmRequest.contents()); - List finalContents = stripThoughts(contents); - return llmRequest.toBuilder().contents(finalContents).build(); + return llmRequest.toBuilder().contents(contents).build(); } /** @@ -142,19 +142,17 @@ static List ensureModelResponse(List contents) { } /** - * Extracts text content from the first part of an LlmResponse, if available. + * Extracts the first part of an LlmResponse, if available. * - * @param llmResponse The LlmResponse to extract text from. - * @return The text content, or an empty string if not found. + * @param llmResponse The LlmResponse to extract the first part from. + * @return The first part, or an empty optional if not found. */ - public static String getTextFromLlmResponse(LlmResponse llmResponse) { + public static Optional getPart0FromLlmResponse(LlmResponse llmResponse) { return llmResponse .content() .flatMap(Content::parts) .filter(parts -> !parts.isEmpty()) - .map(parts -> parts.get(0)) - .flatMap(Part::text) - .orElse(""); + .map(parts -> parts.get(0)); } /** @@ -177,19 +175,4 @@ public static boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse) .flatMap(Part::inlineData) .isEmpty(); } - - /** Removes any `Part` that contains only a `thought` from the content list. */ - public static List stripThoughts(List originalContents) { - return originalContents.stream() - .map( - content -> { - ImmutableList nonThoughtParts = - content.parts().orElse(ImmutableList.of()).stream() - // Keep if thought is not present OR if thought is present but false - .filter(part -> part.thought().map(isThought -> !isThought).orElse(true)) - .collect(toImmutableList()); - return content.toBuilder().parts(nonThoughtParts).build(); - }) - .collect(toImmutableList()); - } } diff --git a/core/src/test/java/com/google/adk/models/GeminiTest.java b/core/src/test/java/com/google/adk/models/GeminiTest.java index 976e1b19..07dd675e 100644 --- a/core/src/test/java/com/google/adk/models/GeminiTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiTest.java @@ -139,7 +139,8 @@ private void assertLlmResponses( private static Predicate isPartialTextResponse(String expectedText) { return response -> { assertThat(response.partial()).hasValue(true); - assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); return true; }; } @@ -147,7 +148,8 @@ private static Predicate isPartialTextResponse(String expectedText) private static Predicate isFinalTextResponse(String expectedText) { return response -> { assertThat(response.partial()).isEmpty(); - assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEqualTo(expectedText); return true; }; } @@ -162,7 +164,8 @@ private static Predicate isFunctionCallResponse() { private static Predicate isEmptyResponse() { return response -> { assertThat(response.partial()).isEmpty(); - assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse("")) + .isEmpty(); return true; }; } diff --git a/core/src/test/java/com/google/adk/models/GeminiUtilTest.java b/core/src/test/java/com/google/adk/models/GeminiUtilTest.java index 0fae0767..49e73511 100644 --- a/core/src/test/java/com/google/adk/models/GeminiUtilTest.java +++ b/core/src/test/java/com/google/adk/models/GeminiUtilTest.java @@ -39,125 +39,49 @@ public final class GeminiUtilTest { Content.fromParts(Part.fromText(GeminiUtil.CONTINUE_OUTPUT_MESSAGE)); @Test - public void stripThoughts_emptyList_returnsEmptyList() { - assertThat(GeminiUtil.stripThoughts(ImmutableList.of())).isEmpty(); - } - - @Test - public void stripThoughts_contentWithNoParts_returnsContentWithNoParts() { - Content content = Content.builder().build(); - Content expected = toContent(); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(content)); - - assertThat(result).containsExactly(expected); - } - - @Test - public void stripThoughts_partsWithoutThought_returnsAllParts() { - Part part1 = createTextPart("Hello"); - Part part2 = createTextPart("World"); - Content content = toContent(part1, part2); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(content)); - - assertThat(result.get(0).parts().get()).containsExactly(part1, part2).inOrder(); - } - - @Test - public void stripThoughts_partsWithThoughtFalse_returnsAllParts() { - Part part1 = createThoughtPart("Regular text", false); - Part part2 = createTextPart("Another text"); - Content content = toContent(part1, part2); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(content)); - - assertThat(result.get(0).parts().get()).containsExactly(part1, part2).inOrder(); - } - - @Test - public void stripThoughts_partsWithThoughtTrue_stripsThoughtParts() { - Part part1 = createTextPart("Visible text"); - Part part2 = createThoughtPart("Internal thought", true); - Part part3 = createTextPart("More visible text"); - Content content = toContent(part1, part2, part3); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(content)); - - assertThat(result.get(0).parts().get()).containsExactly(part1, part3).inOrder(); - } - - @Test - public void stripThoughts_mixedParts_stripsOnlyThoughtTrue() { - Part part1 = createTextPart("Text 1"); - Part part2 = createThoughtPart("Thought 1", true); - Part part3 = createTextPart("Text 2"); - Part part4 = createThoughtPart("Not a thought", false); - Part part5 = createThoughtPart("Thought 2", true); - Content content = toContent(part1, part2, part3, part4, part5); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(content)); - - assertThat(result.get(0).parts().get()).containsExactly(part1, part3, part4).inOrder(); - } - - @Test - public void stripThoughts_multipleContents_stripsThoughtsFromEach() { - Part partA1 = createTextPart("A1"); - Part partA2 = createThoughtPart("A2 Thought", true); - Content contentA = toContent(partA1, partA2); - - Part partB1 = createThoughtPart("B1 Thought", true); - Part partB2 = createTextPart("B2"); - Part partB3 = createThoughtPart("B3 Not Thought", false); - Content contentB = toContent(partB1, partB2, partB3); - - List result = GeminiUtil.stripThoughts(ImmutableList.of(contentA, contentB)); + public void getPart0FromLlmResponse_noContent_returnsEmpty() { + LlmResponse llmResponse = LlmResponse.builder().build(); - assertThat(result).hasSize(2); - assertThat(result.get(0).parts().get()).containsExactly(partA1); - assertThat(result.get(1).parts().get()).containsExactly(partB2, partB3).inOrder(); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty(); } @Test - public void getTextFromLlmResponse_noContent_returnsEmptyString() { - LlmResponse llmResponse = LlmResponse.builder().build(); + public void getPart0FromLlmResponse_contentWithNoParts_returnsEmpty() { + LlmResponse llmResponse = toResponse(Content.builder().build()); - assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty(); } @Test - public void getTextFromLlmResponse_contentWithNoParts_returnsEmptyString() { - LlmResponse llmResponse = toResponse(Content.builder().build()); + public void getPart0FromLlmResponse_contentWithEmptyPartsList_returnsEmpty() { + LlmResponse llmResponse = toResponse(toContent()); - assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty(); } @Test - public void getTextFromLlmResponse_firstPartHasNoText_returnsEmptyString() { - Part part1 = Part.builder().inlineData(Blob.builder().mimeType("image/png").build()).build(); - LlmResponse llmResponse = toResponse(part1); + public void getPart0FromLlmResponse_contentWithSinglePart_returnsFirstPart() { + Part expectedPart = createTextPart("Hello world"); + LlmResponse llmResponse = toResponse(expectedPart); - assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty(); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart); } @Test - public void getTextFromLlmResponse_firstPartHasText_returnsText() { - String expectedText = "The quick brown fox."; - Part part1 = createTextPart(expectedText); - LlmResponse llmResponse = toResponse(part1); + public void getPart0FromLlmResponse_contentWithMultipleParts_returnsFirstPart() { + Part firstPart = createTextPart("First part"); + Part secondPart = createTextPart("Second part"); + LlmResponse llmResponse = toResponse(firstPart, secondPart); - assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(firstPart); } @Test - public void getTextFromLlmResponse_multipleParts_returnsTextFromFirstPartOnly() { - String expectedText = "First part text."; - Part part1 = createTextPart(expectedText); - Part part2 = createTextPart("Second part text."); - LlmResponse llmResponse = toResponse(part1, part2); + public void getPart0FromLlmResponse_contentWithThoughtPart_returnsFirstPart() { + Part expectedPart = createThoughtPart("I need to think about this", true); + LlmResponse llmResponse = toResponse(expectedPart); - assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText); + assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart); } @Test