diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java b/spring-ai-commons/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java index c6f9b6a5cae..e280518c8ce 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/document/DefaultContentFormatter.java @@ -19,7 +19,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -121,13 +120,13 @@ public String format(Document document, MetadataMode metadataMode) { * @param metadata Document metadata. * @return Returns the filtered by configured mode metadata. */ - protected Map metadataFilter(Map metadata, MetadataMode metadataMode) { + private Map metadataFilter(Map metadata, MetadataMode metadataMode) { if (metadataMode == MetadataMode.ALL) { - return new HashMap<>(metadata); + return metadata; } if (metadataMode == MetadataMode.NONE) { - return new HashMap<>(Collections.emptyMap()); + return Collections.emptyMap(); } Set usableMetadataKeys = new HashSet<>(metadata.keySet()); @@ -139,10 +138,10 @@ else if (metadataMode == MetadataMode.EMBED) { usableMetadataKeys.removeAll(this.excludedEmbedMetadataKeys); } - return new HashMap<>(metadata.entrySet() + return metadata.entrySet() .stream() .filter(e -> usableMetadataKeys.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } public String getMetadataTemplate() { diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java b/spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java index 7ffc01fee5d..f421b11591e 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java @@ -50,6 +50,7 @@ * @author Mark Pollack * @author Laura Trotta * @author Jihoon Kim + * @author Yanming Zhou * @since 1.0.0 */ public class TokenCountBatchingStrategy implements BatchingStrategy { @@ -153,15 +154,15 @@ public List> batch(List documents) { documentTokens.put(document, tokenCount); } - for (Document document : documentTokens.keySet()) { - Integer tokenCount = documentTokens.get(document); - if (currentSize + tokenCount > this.maxInputTokenCount) { + for (Map.Entry entry : documentTokens.entrySet()) { + Document document = entry.getKey(); + currentSize += entry.getValue(); + if (currentSize > this.maxInputTokenCount) { batches.add(currentBatch); currentBatch = new ArrayList<>(); currentSize = 0; } currentBatch.add(document); - currentSize += tokenCount; } if (!currentBatch.isEmpty()) { batches.add(currentBatch);