diff --git a/docs/release_notes.md b/docs/release_notes.md index babf0b8c7..f872d0927 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -12,7 +12,8 @@ ### ✨ New Functionality -- Extend `OpenAiClientException` and `OrchestrationClientException` to retrieve error diagnostics information received from remote service. +- Extend `OpenAiClientException` and `OrchestrationClientException` to retrieve error diagnostics information received + from remote service. New available accessors for troubleshooting: `getErrorResponse()`, `getHttpResponse()` and `getHttpRequest()`. Please note: depending on the error response, these methods may return `null` if the information is not available. - [OpenAI] Added new models for `OpenAiModel`: `GPT_5`, `GPT_5_MINI` and `GPT_5_NANO`. @@ -22,6 +23,8 @@ `OrchestrationAiModel.GEMINI_1_5_FLASH` - Replacement are `GEMINI_2_5_PRO` and `GEMINI_2_5_FLASH`. - [Orchestration] Deprecated `OrchestrationAiModel.IBM_GRANITE_13B_CHAT` with no replacement. +- [OpenAI] [Introduced SpringAI integration with our OpenAI client.](https://sap.github.io/ai-sdk/docs/java/spring-ai/openai) + - Added `OpenAiChatModel` ### 📈 Improvements diff --git a/foundation-models/openai/pom.xml b/foundation-models/openai/pom.xml index 9ab13fbab..40d8d7cd5 100644 --- a/foundation-models/openai/pom.xml +++ b/foundation-models/openai/pom.xml @@ -38,12 +38,12 @@ ${project.basedir}/../../ - 72% - 80% - 76% - 70% - 83% - 84% + 81% + 91% + 88% + 79% + 90% + 92% @@ -112,6 +112,11 @@ spring-ai-model true + + io.projectreactor + reactor-core + true + org.projectlombok @@ -149,6 +154,12 @@ javaparser-core test + + org.springframework.ai + spring-ai-client-chat + test + true + diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java index fb2cdcaed..7a5b9c112 100644 --- 
a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiMessage.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai; import com.google.common.annotations.Beta; +import java.util.ArrayList; import java.util.List; import javax.annotation.Nonnull; @@ -46,6 +47,18 @@ static OpenAiAssistantMessage assistant(@Nonnull final String message) { return new OpenAiAssistantMessage(message); } + /** + * A convenience method to create an assistant message. + * + * @param toolCalls tool calls to associate with the message. + * @return the assistant message. + */ + @Nonnull + static OpenAiAssistantMessage assistant(@Nonnull final List toolCalls) { + return new OpenAiAssistantMessage( + new OpenAiMessageContent(List.of()), new ArrayList<>(toolCalls)); + } + /** * A convenience method to create a system message. * diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java index 9c400c4f4..9a4d3ff27 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiToolCall.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai; import com.google.common.annotations.Beta; +import javax.annotation.Nonnull; /** * Represents a tool called by an OpenAI model. @@ -8,4 +9,19 @@ * @since 1.6.0 */ @Beta -public sealed interface OpenAiToolCall permits OpenAiFunctionCall {} +public sealed interface OpenAiToolCall permits OpenAiFunctionCall { + /** + * Creates a new instance of {@link OpenAiToolCall}. + * + * @param id The unique identifier for the tool call. + * @param name The name of the tool to be called. 
+ * @param arguments The arguments for the tool call, encoded as a JSON string. + * @return A new instance of {@link OpenAiToolCall}. + * @since 1.10.0 + */ + @Nonnull + static OpenAiToolCall function( + @Nonnull final String id, @Nonnull final String name, @Nonnull final String arguments) { + return new OpenAiFunctionCall(id, name, arguments); + } +} diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java new file mode 100644 index 000000000..c31287907 --- /dev/null +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModel.java @@ -0,0 +1,223 @@ +package com.sap.ai.sdk.foundationmodels.openai.spring; + +import static org.springframework.ai.model.tool.ToolCallingChatOptions.isInternalToolExecutionEnabled; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionDelta; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionRequest; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiChatCompletionResponse; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiMessage; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiToolCall; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionMessageToolCall; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.ChatCompletionTool; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.CreateChatCompletionResponseChoicesInner; +import com.sap.ai.sdk.foundationmodels.openai.generated.model.FunctionObject; +import io.vavr.control.Option; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; 
+import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.metadata.ChatGenerationMetadata; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import reactor.core.publisher.Flux; + +/** + * OpenAI Chat Model implementation that interacts with the OpenAI API to generate chat completions. 
+ */ +@Slf4j +@RequiredArgsConstructor +public class OpenAiChatModel implements ChatModel { + + private final OpenAiClient client; + + @Nonnull + private final DefaultToolCallingManager toolCallingManager = + DefaultToolCallingManager.builder().build(); + + @Override + @Nonnull + public ChatResponse call(@Nonnull final Prompt prompt) { + val options = prompt.getOptions(); + var request = new OpenAiChatCompletionRequest(extractMessages(prompt)); + + if (options != null) { + request = extractOptions(request, options); + } + if ((options instanceof ToolCallingChatOptions toolOptions)) { + request = request.withTools(extractTools(toolOptions)); + } + + val result = client.chatCompletion(request); + val response = new ChatResponse(toGenerations(result)); + + if (options != null && isInternalToolExecutionEnabled(options) && response.hasToolCalls()) { + val toolExecutionResult = toolCallingManager.executeToolCalls(prompt, response); + // Send the tool execution result back to the model. + return call(new Prompt(toolExecutionResult.conversationHistory(), options)); + } + return response; + } + + @Override + @Nonnull + public Flux stream(@Nonnull final Prompt prompt) { + val options = prompt.getOptions(); + var request = new OpenAiChatCompletionRequest(extractMessages(prompt)); + + if (options != null) { + request = extractOptions(request, options); + } + if ((options instanceof ToolCallingChatOptions toolOptions)) { + request = request.withTools(extractTools(toolOptions)); + } + + val stream = client.streamChatCompletionDeltas(request); + final Flux flux = + Flux.generate( + stream::iterator, + (iterator, sink) -> { + if (iterator.hasNext()) { + sink.next(iterator.next()); + } else { + sink.complete(); + } + return iterator; + }); + return flux.map( + delta -> { + val assistantMessage = new AssistantMessage(delta.getDeltaContent(), Map.of()); + val metadata = + ChatGenerationMetadata.builder().finishReason(delta.getFinishReason()).build(); + return new 
ChatResponse(List.of(new Generation(assistantMessage, metadata))); + }); + } + + private static List extractMessages(final Prompt prompt) { + final List result = new ArrayList<>(); + for (final Message message : prompt.getInstructions()) { + switch (message.getMessageType()) { + case USER -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.user(t))); + case SYSTEM -> Option.of(message.getText()).peek(t -> result.add(OpenAiMessage.system(t))); + case ASSISTANT -> addAssistantMessage(result, (AssistantMessage) message); + case TOOL -> addToolMessages(result, (ToolResponseMessage) message); + } + } + return result; + } + + private static void addAssistantMessage( + final List result, final AssistantMessage message) { + if (message.getText() != null) { + result.add(OpenAiMessage.assistant(message.getText())); + return; + } + final Function callTranslate = + toolCall -> OpenAiToolCall.function(toolCall.id(), toolCall.name(), toolCall.arguments()); + val calls = message.getToolCalls().stream().map(callTranslate).toList(); + result.add(OpenAiMessage.assistant(calls)); + } + + private static void addToolMessages( + final List result, final ToolResponseMessage message) { + for (final ToolResponseMessage.ToolResponse response : message.getResponses()) { + result.add(OpenAiMessage.tool(response.responseData(), response.id())); + } + } + + @Nonnull + private static List toGenerations( + @Nonnull final OpenAiChatCompletionResponse result) { + return result.getOriginalResponse().getChoices().stream() + .map(OpenAiChatModel::toGeneration) + .toList(); + } + + @Nonnull + private static Generation toGeneration( + @Nonnull final CreateChatCompletionResponseChoicesInner choice) { + val metadata = + ChatGenerationMetadata.builder().finishReason(choice.getFinishReason().getValue()); + metadata.metadata("index", choice.getIndex()); + if (choice.getLogprobs() != null && !choice.getLogprobs().getContent().isEmpty()) { + metadata.metadata("logprobs", 
choice.getLogprobs().getContent()); + } + val message = choice.getMessage(); + val calls = new ArrayList(); + if (message.getToolCalls() != null) { + for (final ChatCompletionMessageToolCall c : message.getToolCalls()) { + val fnc = c.getFunction(); + calls.add( + new ToolCall(c.getId(), c.getType().getValue(), fnc.getName(), fnc.getArguments())); + } + } + + val assistantMessage = new AssistantMessage(message.getContent(), Map.of(), calls); + return new Generation(assistantMessage, metadata.build()); + } + + /** + * Adds options to the request. + * + * @param request the request to modify + * @param options the options to extract + * @return the modified request with options applied + */ + @Nonnull + protected static OpenAiChatCompletionRequest extractOptions( + @Nonnull OpenAiChatCompletionRequest request, @Nonnull final ChatOptions options) { + request = request.withStop(options.getStopSequences()).withMaxTokens(options.getMaxTokens()); + if (options.getTemperature() != null) { + request = request.withTemperature(BigDecimal.valueOf(options.getTemperature())); + } + if (options.getTopP() != null) { + request = request.withTopP(BigDecimal.valueOf(options.getTopP())); + } + if (options.getPresencePenalty() != null) { + request = request.withPresencePenalty(BigDecimal.valueOf(options.getPresencePenalty())); + } + if (options.getFrequencyPenalty() != null) { + request = request.withFrequencyPenalty(BigDecimal.valueOf(options.getFrequencyPenalty())); + } + return request; + } + + private static List extractTools(final ToolCallingChatOptions options) { + val tools = new ArrayList(); + for (val toolCallback : options.getToolCallbacks()) { + val toolDefinition = toolCallback.getToolDefinition(); + try { + final Map params = + new ObjectMapper().readValue(toolDefinition.inputSchema(), new TypeReference<>() {}); + val toolType = ChatCompletionTool.TypeEnum.FUNCTION; + val toolFunction = + new FunctionObject() + .name(toolDefinition.name()) + 
.description(toolDefinition.description()) + .parameters(params); + val tool = new ChatCompletionTool().type(toolType).function(toolFunction); + tools.add(tool); + } catch (JsonProcessingException e) { + log.warn("Failed to add tool to the chat request: {}", e.getMessage()); + } + } + return tools; + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModelTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModelTest.java new file mode 100644 index 000000000..03b38616f --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/OpenAiChatModelTest.java @@ -0,0 +1,237 @@ +package com.sap.ai.sdk.foundationmodels.openai.spring; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import 
com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Cache; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import lombok.val; +import org.apache.hc.client5.http.classic.HttpClient; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.InputStreamEntity; +import org.apache.hc.core5.http.message.BasicClassicHttpResponse; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage.ToolCall; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; +import org.springframework.ai.support.ToolCallbacks; +import reactor.core.publisher.Flux; + +@WireMockTest +public class OpenAiChatModelTest { + + private final Function fileLoader = + filename -> Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream(filename)); + + private static OpenAiChatModel client; + private static Prompt prompt; + + @BeforeEach + void setup(WireMockRuntimeInfo server) { + final DefaultHttpDestination destination = + DefaultHttpDestination.builder(server.getHttpBaseUrl()).build(); + client = new OpenAiChatModel(OpenAiClient.withCustomDestination(destination)); + prompt = new Prompt("Hello World! 
Why is this phrase so famous?"); + ApacheHttpClient5Accessor.setHttpClientCache(ApacheHttpClient5Cache.DISABLED); + } + + @AfterEach + void reset() { + ApacheHttpClient5Accessor.setHttpClientCache(null); + ApacheHttpClient5Accessor.setHttpClientFactory(null); + } + + @Test + void testCompletion() { + stubFor( + post(urlPathEqualTo("/chat/completions")) + .withQueryParam("api-version", equalTo("2024-02-01")) + .willReturn( + aResponse() + .withBodyFile("chatCompletionResponse.json") + .withHeader("Content-Type", "application/json"))); + val result = client.call(prompt); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getText()).isNotEmpty(); + } + + @Test + void testStreamCompletion() throws IOException { + try (val inputStream = spy(fileLoader.apply("streamChatCompletion.txt"))) { + + val httpClient = mock(HttpClient.class); + ApacheHttpClient5Accessor.setHttpClientFactory(destination -> httpClient); + + // Create a mock response + val mockResponse = new BasicClassicHttpResponse(200, "OK"); + val inputStreamEntity = new InputStreamEntity(inputStream, ContentType.TEXT_PLAIN); + mockResponse.setEntity(inputStreamEntity); + mockResponse.setHeader("Content-Type", "text/event-flux"); + + // Configure the HttpClient mock to return the mock response + doReturn(mockResponse).when(httpClient).executeOpen(any(), any(), any()); + + Flux flux = client.stream(prompt); + val deltaList = flux.toStream().toList(); + + assertThat(deltaList).hasSize(5); + // the first delta doesn't have any content + assertThat(deltaList.get(0).getResult().getOutput().getText()).isEqualTo(""); + assertThat(deltaList.get(1).getResult().getOutput().getText()).isEqualTo(""); + assertThat(deltaList.get(2).getResult().getOutput().getText()).isEqualTo("Sure"); + assertThat(deltaList.get(3).getResult().getOutput().getText()).isEqualTo("!"); + assertThat(deltaList.get(4).getResult().getOutput().getText()).isEqualTo(""); + + 
assertThat(deltaList.get(0).getResult().getMetadata().getFinishReason()).isEqualTo(null); + assertThat(deltaList.get(1).getResult().getMetadata().getFinishReason()).isEqualTo(null); + assertThat(deltaList.get(2).getResult().getMetadata().getFinishReason()).isEqualTo(null); + assertThat(deltaList.get(3).getResult().getMetadata().getFinishReason()).isEqualTo(null); + assertThat(deltaList.get(4).getResult().getMetadata().getFinishReason()).isEqualTo("stop"); + + Mockito.verify(inputStream, times(1)).close(); + } + } + + @Test + void testToolCallsWithoutExecution() throws IOException { + stubFor( + post(urlPathEqualTo("/chat/completions")) + .willReturn( + aResponse() + .withHeader("Content-Type", "application/json") + .withBodyFile("weatherToolResponse.json"))); + + var options = new DefaultToolCallingChatOptions(); + options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + options.setInternalToolExecutionEnabled(false); + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + val result = client.call(prompt); + + List toolCalls = result.getResult().getOutput().getToolCalls(); + assertThat(toolCalls).hasSize(2); + ToolCall toolCall1 = toolCalls.get(0); + ToolCall toolCall2 = toolCalls.get(1); + assertThat(toolCall1.type()).isEqualTo("function"); + assertThat(toolCall2.type()).isEqualTo("function"); + assertThat(toolCall1.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall2.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall1.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}"); + assertThat(toolCall2.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}"); + + try (var request1InputStream = fileLoader.apply("toolCallsRequest.json")) { + final String request1 = new String(request1InputStream.readAllBytes()); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request1))); + } + } + + @Test + void 
testToolCallsWithExecution() throws IOException { + // https://platform.openai.com/docs/guides/function-calling + stubFor( + post(urlPathEqualTo("/chat/completions")) + .inScenario("Tool Calls") + .willReturn( + aResponse() + .withHeader("Content-Type", "application/json") + .withBodyFile("weatherToolResponse.json")) + .willSetStateTo("Second Call")); + + stubFor( + post(urlPathEqualTo("/chat/completions")) + .inScenario("Tool Calls") + .whenScenarioStateIs("Second Call") + .willReturn( + aResponse() + .withBodyFile("weatherToolResponse2.json") + .withHeader("Content-Type", "application/json"))); + + var options = new DefaultToolCallingChatOptions(); + options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + val result = client.call(prompt); + + assertThat(result.getResult().getOutput().getText()) + .isEqualTo("The current temperature in Potsdam is 30°C and in Toulouse 30°C."); + + try (var request1InputStream = fileLoader.apply("toolCallsRequest.json")) { + try (var request2InputStream = fileLoader.apply("toolCallsRequest2.json")) { + final String request1 = new String(request1InputStream.readAllBytes()); + final String request2 = new String(request2InputStream.readAllBytes()); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request1))); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request2))); + } + } + } + + @Test + void testChatMemory() throws IOException { + stubFor( + post(urlPathEqualTo("/chat/completions")) + .inScenario("Chat Memory") + .whenScenarioStateIs(STARTED) + .willReturn( + aResponse() + .withBodyFile("templatingResponse.json") // The response is not important + .withHeader("Content-Type", "application/json")) + .willSetStateTo("Second Call")); + + stubFor( + post(urlPathEqualTo("/chat/completions")) + .inScenario("Chat Memory") + .whenScenarioStateIs("Second Call") + .willReturn( + aResponse() + 
.withBodyFile("templatingResponse2.json") // The response is not important + .withHeader("Content-Type", "application/json"))); + + val repository = new InMemoryChatMemoryRepository(); + val memory = MessageWindowChatMemory.builder().chatMemoryRepository(repository).build(); + val advisor = MessageChatMemoryAdvisor.builder(memory).build(); + val cl = ChatClient.builder(client).defaultAdvisors(advisor).build(); + val prompt1 = new Prompt("What is the capital of France?"); + val prompt2 = new Prompt("And what is the typical food there?"); + + cl.prompt(prompt1).call().content(); + cl.prompt(prompt2).call().content(); + // The response is not important + // We just want to verify that the second call remembered the first call + try (var requestInputStream = fileLoader.apply("chatMemory.json")) { + final String request = new String(requestInputStream.readAllBytes()); + verify(postRequestedFor(anyUrl()).withRequestBody(equalToJson(request))); + } + } +} diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/WeatherMethod.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/WeatherMethod.java new file mode 100644 index 000000000..d2cd25649 --- /dev/null +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/spring/WeatherMethod.java @@ -0,0 +1,41 @@ +package com.sap.ai.sdk.foundationmodels.openai.spring; + +import javax.annotation.Nonnull; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.tool.annotation.ToolParam; + +public class WeatherMethod { + + /** Unit of temperature */ + public enum Unit { + /** Celsius */ + @SuppressWarnings("unused") + C, + /** Fahrenheit */ + @SuppressWarnings("unused") + F + } + + /** + * Request for the weather + * + * @param location the city + * @param unit the unit of temperature + */ + public record Request(String location, Unit unit) {} + + /** + * Response for the weather + * + * @param temp 
the temperature + * @param unit the unit of temperature + */ + public record Response(double temp, Unit unit) {} + + @Nonnull + @SuppressWarnings("unused") + @Tool(description = "Get the weather in location") + Response getCurrentWeather(@ToolParam @Nonnull Request request) { + return new Response(30, request.unit); + } +} diff --git a/foundation-models/openai/src/test/resources/__files/templatingResponse.json b/foundation-models/openai/src/test/resources/__files/templatingResponse.json new file mode 100644 index 000000000..be85a5157 --- /dev/null +++ b/foundation-models/openai/src/test/resources/__files/templatingResponse.json @@ -0,0 +1,75 @@ +{ + "choices": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "finish_reason": "stop", + "index": 0, + "message": { + "annotations": [], + "content": "The capital of France is Paris.", + "refusal": null, + "role": "assistant" + } + } + ], + "created": 1755099738, + "id": "chatcmpl-C47uE2MKhMBeb0jm2QY9OAw8fyNZx", + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "prompt_filter_results": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "prompt_index": 0 + } + ], + "system_fingerprint": "fp_efad92c60b", + "usage": { + "completion_tokens": 8, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens": 14, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + }, + "total_tokens": 22 + } +} \ No newline at 
end of file diff --git a/foundation-models/openai/src/test/resources/__files/templatingResponse2.json b/foundation-models/openai/src/test/resources/__files/templatingResponse2.json new file mode 100644 index 000000000..f96a058c4 --- /dev/null +++ b/foundation-models/openai/src/test/resources/__files/templatingResponse2.json @@ -0,0 +1,75 @@ +{ + "choices": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "finish_reason": "stop", + "index": 0, + "message": { + "annotations": [], + "content": "Typical food in Paris includes a variety of French cuisine, such as:\n\n1. **Croissants** - Flaky, buttery pastries often enjoyed at breakfast.\n2. **Baguette** - A long, thin loaf of French bread, commonly used for making sandwiches.\n3. **Escargots** - Snails typically cooked in garlic butter, often served as an appetizer.\n4. **Coq au Vin** - A classic French dish made with chicken braised in red wine, usually with mushrooms and lardons.\n5. **Ratatouille** - A vegetable dish that includes ingredients like zucchini, eggplant, bell peppers, and tomatoes.\n6. **Duck Confit** - Slow-cooked duck leg that is crispy on the outside and tender on the inside.\n7. **Cr\u00eapes** - Thin pancakes that can be filled with sweet or savory ingredients.\n8. **Tarte Tatin** - An upside-down caramelized apple tart.\n9. 
**Macarons** - Colorful almond meringue cookies filled with ganache, buttercream, or jam.\n\nParis is also known for its vibrant caf\u00e9 culture, where you can enjoy coffee alongside pastries or light meals.", + "refusal": null, + "role": "assistant" + } + } + ], + "created": 1755099739, + "id": "chatcmpl-C47uFlhhRd3CbStgBf77Unh8RJnMG", + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "prompt_filter_results": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "prompt_index": 0 + } + ], + "system_fingerprint": "fp_efad92c60b", + "usage": { + "completion_tokens": 242, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens": 37, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + }, + "total_tokens": 279 + } +} diff --git a/foundation-models/openai/src/test/resources/__files/weatherToolResponse.json b/foundation-models/openai/src/test/resources/__files/weatherToolResponse.json new file mode 100644 index 000000000..14382b9d8 --- /dev/null +++ b/foundation-models/openai/src/test/resources/__files/weatherToolResponse.json @@ -0,0 +1,76 @@ +{ + "choices": [ + { + "content_filter_results": {}, + "finish_reason": "tool_calls", + "index": 0, + "message": { + "annotations": [], + "content": null, + "refusal": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}", + "name": "getCurrentWeather" + }, + "id": "call_MQ7MyYGmoP5TpMSv6AfeWCg5", + "type": "function" + }, + { + "function": { + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}", + "name": "getCurrentWeather" + }, + "id": 
"call_BQpUfvkUUqx7e3yZv7Rmpnxy", + "type": "function" + } + ] + } + } + ], + "created": 1755092903, + "id": "chatcmpl-C467zkarjmr5ggy6qN41vBfseOJBK", + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "prompt_filter_results": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "prompt_index": 0 + } + ], + "system_fingerprint": "fp_efad92c60b", + "usage": { + "completion_tokens": 66, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens": 70, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + }, + "total_tokens": 136 + } +} diff --git a/foundation-models/openai/src/test/resources/__files/weatherToolResponse2.json b/foundation-models/openai/src/test/resources/__files/weatherToolResponse2.json new file mode 100644 index 000000000..2f964869b --- /dev/null +++ b/foundation-models/openai/src/test/resources/__files/weatherToolResponse2.json @@ -0,0 +1,75 @@ +{ + "choices": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "finish_reason": "stop", + "index": 0, + "message": { + "annotations": [], + "content": "The current temperature in Potsdam is 30°C and in Toulouse 30°C.", + "refusal": null, + "role": "assistant" + } + } + ], + "created": 1755092905, + "id": "chatcmpl-C4681YHqzYJIMl0BJy9rtucgtkO8G", + "model": "gpt-4o-mini-2024-07-18", + "object": "chat.completion", + "prompt_filter_results": [ + { + "content_filter_results": { + "hate": { + 
"filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "prompt_index": 0 + } + ], + "system_fingerprint": "fp_efad92c60b", + "usage": { + "completion_tokens": 60, + "completion_tokens_details": { + "accepted_prediction_tokens": 0, + "audio_tokens": 0, + "reasoning_tokens": 0, + "rejected_prediction_tokens": 0 + }, + "prompt_tokens": 175, + "prompt_tokens_details": { + "audio_tokens": 0, + "cached_tokens": 0 + }, + "total_tokens": 235 + } +} \ No newline at end of file diff --git a/foundation-models/openai/src/test/resources/chatMemory.json b/foundation-models/openai/src/test/resources/chatMemory.json new file mode 100644 index 000000000..9bc924d65 --- /dev/null +++ b/foundation-models/openai/src/test/resources/chatMemory.json @@ -0,0 +1,16 @@ +{ + "messages": [ + { + "content": "What is the capital of France?", + "role": "user" + }, + { + "content": "The capital of France is Paris.", + "role": "assistant" + }, + { + "content": "And what is the typical food there?", + "role": "user" + } + ] +} \ No newline at end of file diff --git a/foundation-models/openai/src/test/resources/toolCallsRequest.json b/foundation-models/openai/src/test/resources/toolCallsRequest.json new file mode 100644 index 000000000..761322b7f --- /dev/null +++ b/foundation-models/openai/src/test/resources/toolCallsRequest.json @@ -0,0 +1,47 @@ +{ + "messages": [ + { + "content": "What is the weather in Potsdam and in Toulouse?", + "role": "user" + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "arg0": { + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": 
"string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "unit" + ] + } + }, + "required": [ + "arg0" + ], + "additionalProperties": false + }, + "strict": false + } + } + ] +} diff --git a/foundation-models/openai/src/test/resources/toolCallsRequest2.json b/foundation-models/openai/src/test/resources/toolCallsRequest2.json new file mode 100644 index 000000000..783536354 --- /dev/null +++ b/foundation-models/openai/src/test/resources/toolCallsRequest2.json @@ -0,0 +1,78 @@ +{ + "messages": [ + { + "content": "What is the weather in Potsdam and in Toulouse?", + "role": "user" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_MQ7MyYGmoP5TpMSv6AfeWCg5", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}" + } + }, + { + "id": "call_BQpUfvkUUqx7e3yZv7Rmpnxy", + "type": "function", + "function": { + "name": "getCurrentWeather", + "arguments": "{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}" + } + } + ] + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_MQ7MyYGmoP5TpMSv6AfeWCg5" + }, + { + "role": "tool", + "content": "{\"temp\":30.0,\"unit\":\"C\"}", + "tool_call_id": "call_BQpUfvkUUqx7e3yZv7Rmpnxy" + } + ], + "tools": [ + { + "type": "function", + "function": { + "description": "Get the weather in location", + "name": "getCurrentWeather", + "parameters": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "arg0": { + "type": "object", + "properties": { + "location": { + "type": "string" + }, + "unit": { + "type": "string", + "enum": [ + "C", + "F" + ] + } + }, + "required": [ + "location", + "unit" + ] + } + }, + "required": [ + "arg0" + ], + "additionalProperties": false + }, + "strict": false + } + } + ] +} diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiController.java 
b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiController.java index 7a3d46570..0497ac793 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiController.java @@ -1,13 +1,16 @@ package com.sap.ai.sdk.app.controllers; import com.sap.ai.sdk.app.services.SpringAiOpenAiService; +import javax.annotation.Nonnull; import javax.annotation.Nullable; import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; +import reactor.core.publisher.Flux; @SuppressWarnings("unused") @RestController @@ -24,4 +27,47 @@ Object embed(@Nullable @RequestParam(value = "format", required = false) final S } return response.getResult().getOutput(); } + + @GetMapping("/completion") + Object completion( + @Nullable @RequestParam(value = "format", required = false) final String format) { + val response = service.completion(); + + if ("json".equals(format)) { + return response.getResult(); + } + return response.getResult().getOutput().getText(); + } + + @GetMapping("/streamChatCompletion") + @Nonnull + Flux streamChatCompletion() { + return service + .streamChatCompletion() + .map(chatResponse -> chatResponse.getResult().getOutput().getText()); + } + + @GetMapping("/toolCalling") + Object toolCalling( + @Nullable @RequestParam(value = "format", required = false) final String format) { + val response = service.toolCalling(true); + + if ("json".equals(format)) { + return response.getResult(); + } + final AssistantMessage message = response.getResult().getOutput(); + final String text = 
message.getText(); + return text != null && text.isEmpty() ? message.getToolCalls().toString() : text; + } + + @GetMapping("/chatMemory") + Object chatMemory( + @Nullable @RequestParam(value = "format", required = false) final String format) { + val response = service.chatMemory(); + + if ("json".equals(format)) { + return response.getResult(); + } + return response.getResult().getOutput().getText(); + } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java index 72ef36e29..aac89186e 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/SpringAiOrchestrationController.java @@ -128,7 +128,7 @@ Object toolCalling( } final AssistantMessage message = response.getResult().getOutput(); final String text = message.getText(); - return text.isEmpty() ? message.getToolCalls().toString() : text; + return text != null && text.isEmpty() ? 
message.getToolCalls().toString() : text; } @GetMapping("/mcp") diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java index e1ff3b343..f02a38fb4 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/SpringAiOpenAiService.java @@ -2,20 +2,36 @@ import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel; +import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatModel; import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiSpringEmbeddingModel; import java.util.List; +import java.util.Objects; import javax.annotation.Nonnull; +import lombok.val; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingOptionsBuilder; import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; +import org.springframework.ai.support.ToolCallbacks; import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; /** Service class for Spring AI integration with OpenAI */ @Service public class SpringAiOpenAiService { - private final OpenAiClient client = OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL); + private final 
OpenAiSpringEmbeddingModel embeddingClient = + new OpenAiSpringEmbeddingModel(OpenAiClient.forModel(OpenAiModel.TEXT_EMBEDDING_3_SMALL)); + private final ChatModel chatClient = + new OpenAiChatModel(OpenAiClient.forModel(OpenAiModel.GPT_4O_MINI)); /** * Embeds a list of strings using the OpenAI embedding model. @@ -28,7 +44,7 @@ public EmbeddingResponse embedStrings() { final var springAiRequest = new EmbeddingRequest(List.of("The quick brown fox jumps over the lazy dog."), options); - return new OpenAiSpringEmbeddingModel(client).call(springAiRequest); + return embeddingClient.call(springAiRequest); } /** @@ -39,6 +55,65 @@ public EmbeddingResponse embedStrings() { @Nonnull public float[] embedDocument() { final var document = new Document("The quick brown fox jumps over the lazy dog."); - return new OpenAiSpringEmbeddingModel(client).embed(document); + return embeddingClient.embed(document); + } + + /** + * Chat request to OpenAI through the OpenAI service with a simple prompt. + * + * @return the assistant response object + */ + @Nonnull + public ChatResponse completion() { + val prompt = new Prompt("What is the capital of France?"); + return chatClient.call(prompt); + } + + /** + * Asynchronous stream of an OpenAI chat request + * + * @return a stream of assistant message responses + */ + @Nonnull + public Flux streamChatCompletion() { + val prompt = new Prompt("Can you give me the first 100 numbers of the Fibonacci sequence?"); + return chatClient.stream(prompt); + } + + /** + * Turn a method into a tool by annotating it with @Tool. 
Spring AI + * Tool Method Declarative Specification + * + * @param internalToolExecutionEnabled whether the internal tool execution is enabled + * @return the assistant response object + */ + @Nonnull + public ChatResponse toolCalling(final boolean internalToolExecutionEnabled) { + val options = new DefaultToolCallingChatOptions(); + options.setToolCallbacks(List.of(ToolCallbacks.from(new WeatherMethod()))); + options.setInternalToolExecutionEnabled(internalToolExecutionEnabled); + + val prompt = new Prompt("What is the weather in Potsdam and in Toulouse?", options); + return chatClient.call(prompt); + } + + /** + * Chat request to OpenAI through the OpenAI service using chat memory. + * + * @return the assistant response object + */ + @Nonnull + public ChatResponse chatMemory() { + val repository = new InMemoryChatMemoryRepository(); + val memory = MessageWindowChatMemory.builder().chatMemoryRepository(repository).build(); + val advisor = MessageChatMemoryAdvisor.builder(memory).build(); + val cl = ChatClient.builder(chatClient).defaultAdvisors(advisor).build(); + val prompt1 = new Prompt("What is the capital of France?"); + val prompt2 = new Prompt("And what is the typical food there?"); + + cl.prompt(prompt1).call().content(); + return Objects.requireNonNull( + cl.prompt(prompt2).call().chatResponse(), "Chat response is null"); } } diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 0e4ac11d6..61e2a9f68 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -784,7 +784,8 @@
Orchestration Integration
/spring-ai-orchestration/mcp
- Use an MCP file system server as tool to answer questions about the SDK itself. ⚠️ Only works if the server is started with the "mcp" Spring profile ⚠️. + Use an MCP file system server as tool to answer questions about the SDK itself. + ⚠️ Only works if the server is started with the "mcp" Spring profile ⚠️.
@@ -837,11 +838,61 @@
OpenAI
/spring-ai-openai/embed/strings
- Get the embedding for a given string using SpringAI from + Get the embedding for a given string using OpenAI. +
+ + +
  • +
    + +
    + Chat request with a simple prompt using OpenAI. +
    +
    +
  • +
  • +
    + +
    + Asynchronous stream of a request using OpenAI. +
    +
    +
  • +
  • +
    + +
    + Register a function that will be called when the user asks for the weather using OpenAI.
  • +
  • +
    + +
    + The user first asks for the capital of France, then the typical + food there, chat memory will remember that the user is + inquiring about France using OpenAI. +
    +
    +
  • diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java index 7d3ea42fd..285b31a70 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/SpringAiOpenAiTest.java @@ -4,11 +4,17 @@ import com.sap.ai.sdk.app.services.SpringAiOpenAiService; import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; class SpringAiOpenAiTest { private final SpringAiOpenAiService service = new SpringAiOpenAiService(); + private static final org.slf4j.Logger log = + org.slf4j.LoggerFactory.getLogger(SpringAiOrchestrationTest.class); @Test void testEmbedStrings() { @@ -23,4 +29,66 @@ void testEmbedStrings() { assertThat(response.getMetadata().getModel()) .isEqualTo(OpenAiModel.TEXT_EMBEDDING_3_SMALL.name()); } + + @Test + void testCompletion() { + ChatResponse response = service.completion(); + assertThat(response).isNotNull(); + assertThat(response.getResult().getOutput().getText()).contains("Paris"); + } + + @Test + void testStreamChatCompletion() { + final var stream = service.streamChatCompletion().toStream(); + + final var filledDeltaCount = new AtomicInteger(0); + stream + // foreach consumes all elements, closing the stream at the end + .forEach( + delta -> { + log.info("delta: {}", delta); + String text = delta.getResult().getOutput().getText(); + if (text != null && !text.isEmpty()) { + filledDeltaCount.incrementAndGet(); + } + }); + + // the first two and the last delta don't have any content + // see OpenAiChatCompletionDelta#getDeltaContent + 
assertThat(filledDeltaCount.get()).isGreaterThan(0); + } + + @Test + void testToolCallingWithExecution() { + ChatResponse response = service.toolCalling(true); + assertThat(response.getResult().getOutput().getText()).contains("Potsdam", "Toulouse", "°C"); + } + + @Test + void testToolCallingWithoutExecution() { + ChatResponse response = service.toolCalling(false); + List toolCalls = response.getResult().getOutput().getToolCalls(); + assertThat(toolCalls).hasSize(2); + AssistantMessage.ToolCall toolCall1 = toolCalls.get(0); + AssistantMessage.ToolCall toolCall2 = toolCalls.get(1); + assertThat(toolCall1.type()).isEqualTo("function"); + assertThat(toolCall2.type()).isEqualTo("function"); + assertThat(toolCall1.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall2.name()).isEqualTo("getCurrentWeather"); + assertThat(toolCall1.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Potsdam\", \"unit\": \"C\"}}"); + assertThat(toolCall2.arguments()) + .isEqualTo("{\"arg0\": {\"location\": \"Toulouse\", \"unit\": \"C\"}}"); + } + + @Test + void testChatMemory() { + ChatResponse response = service.chatMemory(); + assertThat(response).isNotNull(); + String text = response.getResult().getOutput().getText(); + log.info(text); + assertThat(text) + .containsAnyOf( + "French", "onion", "pastries", "cheese", "baguette", "coq au vin", "foie gras"); + } }