diff --git a/java/src/main/java/com/genexus/GXProcedure.java b/java/src/main/java/com/genexus/GXProcedure.java index d2b65dfd2..008cd1d45 100644 --- a/java/src/main/java/com/genexus/GXProcedure.java +++ b/java/src/main/java/com/genexus/GXProcedure.java @@ -275,15 +275,25 @@ protected String callAssistant(String agent, GXProperties properties, ArrayList< } protected ChatResult chatAgent(String agent, GXProperties properties, ArrayList messages, CallResult result) { - callAgent(agent, true, properties, messages, result); - return new ChatResult(this, agent, properties, messages, result, client); + ChatResult chatResult = new ChatResult(); + + new Thread(() -> { + try { + context.setThreadModelContext(context); + callAgent(agent, true, properties, messages, result, chatResult); + } finally { + chatResult.markDone(); + } + }).start(); + + return chatResult; } protected String callAgent(String agent, GXProperties properties, ArrayList messages, CallResult result) { - return callAgent(agent, false, properties, messages, result); + return callAgent(agent, false, properties, messages, result, null); } - protected String callAgent(String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result) { + protected String callAgent(String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ChatResult chatResult) { OpenAIRequest aiRequest = new OpenAIRequest(); aiRequest.setModel(String.format("saia:agent:%s", agent)); if (!messages.isEmpty()) @@ -292,7 +302,7 @@ protected String callAgent(String agent, boolean stream, GXProperties properties if (stream) aiRequest.setStream(true); client = new HttpClient(); - OpenAIResponse aiResponse = SaiaService.call(aiRequest, client, result); + OpenAIResponse aiResponse = SaiaService.call(this, aiRequest, client, agent, stream, properties, messages, result, chatResult); if (aiResponse != null && aiResponse.getChoices() != null) { for (OpenAIResponse.Choice element : 
aiResponse.getChoices()) { String finishReason = element.getFinishReason(); @@ -300,7 +310,7 @@ protected String callAgent(String agent, boolean stream, GXProperties properties return element.getMessage().getStringContent(); if (finishReason.equals("tool_calls")) { messages.add(element.getMessage()); - return processNotChunkedResponse(agent, stream, properties, messages, result, element.getMessage().getToolCalls()); + return processNotChunkedResponse(agent, stream, properties, messages, result, chatResult, element.getMessage().getToolCalls()); } } } else if (client.getStatusCode() == 200) { @@ -309,11 +319,11 @@ protected String callAgent(String agent, boolean stream, GXProperties properties return ""; } - public String processNotChunkedResponse(String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ArrayList toolCalls) { + public String processNotChunkedResponse(String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ChatResult chatResult, ArrayList toolCalls) { for (OpenAIResponse.ToolCall tollCall : toolCalls) { processToolCall(tollCall, messages); } - return callAgent(agent, stream, properties, messages, result); + return callAgent(agent, stream, properties, messages, result, chatResult); } private void processToolCall(OpenAIResponse.ToolCall toolCall, ArrayList messages) { diff --git a/java/src/main/java/com/genexus/util/ChatResult.java b/java/src/main/java/com/genexus/util/ChatResult.java index 7496a9289..376ea660f 100644 --- a/java/src/main/java/com/genexus/util/ChatResult.java +++ b/java/src/main/java/com/genexus/util/ChatResult.java @@ -1,54 +1,38 @@ package com.genexus.util; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.genexus.GXProcedure; -import com.genexus.internet.HttpClient; -import com.genexus.util.saia.OpenAIResponse; -import org.json.JSONObject; - -import java.util.ArrayList; +import java.util.concurrent.BlockingQueue; +import 
/**
 * Hand-off buffer between a producer that receives streamed chat chunks and a
 * consumer that polls for them.
 *
 * <p>The producer calls {@link #addChunk(String)} for every piece of text and
 * {@link #markDone()} exactly once when the stream is over. The consumer loops
 * on {@link #hasMoreData()} and reads with {@link #getMoreData()}.
 *
 * <p>All state lives in a {@link LinkedBlockingQueue} and a {@code volatile}
 * flag, so producer and consumer may run on different threads.
 */
public class ChatResult {
	// Deliberately a fresh String instance rather than the interned literal:
	// the marker is recognised by identity (==), so a genuine chunk whose text
	// happens to be "__END__" is still delivered to the consumer.
	private static final String END_MARKER = new String("__END__");

	private final BlockingQueue<String> chunks = new LinkedBlockingQueue<>();
	private volatile boolean done = false;

	/**
	 * Queues one chunk of response text for the consumer.
	 *
	 * @param chunk text to enqueue; {@code null} is ignored
	 */
	public void addChunk(String chunk) {
		// The queue is already thread-safe, so no extra locking is needed here.
		if (chunk != null) {
			chunks.offer(chunk);
		}
	}

	/**
	 * Signals that no further chunks will be produced. The end marker is
	 * enqueued so that a consumer currently blocked in {@link #getMoreData()}
	 * wakes up.
	 */
	public void markDone() {
		done = true;
		chunks.offer(END_MARKER);
	}

	/**
	 * Returns the next chunk, waiting for one if the stream is still open.
	 *
	 * @return the next chunk, or an empty string when the end of the stream is
	 *         reached, the stream is already finished and drained, or the
	 *         waiting thread is interrupted
	 */
	public String getMoreData() {
		// Once the stream is finished and drained nothing will ever arrive;
		// calling take() here would block the caller forever.
		if (done && chunks.isEmpty()) {
			return "";
		}
		try {
			String chunk = chunks.take();
			// Identity comparison on purpose: only the sentinel object ends the stream.
			if (chunk == END_MARKER) {
				return "";
			}
			return chunk;
		} catch (InterruptedException e) {
			// Preserve the interrupt status for the caller.
			Thread.currentThread().interrupt();
			return "";
		}
	}

	/**
	 * Tells whether the consumer should keep calling {@link #getMoreData()}.
	 *
	 * @return {@code false} only after {@link #markDone()} was called and the
	 *         queue (including the end marker) has been drained
	 */
	public boolean hasMoreData() {
		return !(done && chunks.isEmpty());
	}
}
a/java/src/main/java/com/genexus/util/saia/SaiaService.java b/java/src/main/java/com/genexus/util/saia/SaiaService.java index d34e1ffcc..62f2872ba 100644 --- a/java/src/main/java/com/genexus/util/saia/SaiaService.java +++ b/java/src/main/java/com/genexus/util/saia/SaiaService.java @@ -1,31 +1,36 @@ package com.genexus.util.saia; import com.fasterxml.jackson.databind.ObjectMapper; +import com.genexus.GXProcedure; import com.genexus.SdtMessages_Message; import com.genexus.common.interfaces.SpecificImplementation; import com.genexus.diagnostics.core.ILogger; import com.genexus.diagnostics.core.LogManager; import com.genexus.internet.HttpClient; +import com.genexus.util.ChatResult; +import com.genexus.util.GXProperties; import org.json.JSONObject; import com.genexus.util.CallResult; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; + public class SaiaService { private static final ILogger logger = LogManager.getLogger(SaiaService.class); private static final String apiKey = (String) SpecificImplementation.Application.getProperty("AI_PROVIDER_API_KEY", "");; private static final String aiProvider = (String) SpecificImplementation.Application.getProperty("AI_PROVIDER", ""); private static final Logger log = LoggerFactory.getLogger(SaiaService.class); - public static OpenAIResponse call(OpenAIRequest request, HttpClient client, CallResult result) { - return call(request, false, client, result); + public static OpenAIResponse call(GXProcedure proc, OpenAIRequest request, HttpClient client, String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ChatResult chatResult) { + return call(proc, request, false, client, agent, stream, properties, messages, result, chatResult); } public static OpenAIResponse call(OpenAIRequest request, boolean isEmbedding, CallResult result) { - return call(request, isEmbedding, new HttpClient(), result); + return call(null, request, isEmbedding, new HttpClient(), null, 
false, null, null, result, null); } - public static OpenAIResponse call(OpenAIRequest request, boolean isEmbedding, HttpClient client, CallResult result) { + public static OpenAIResponse call(GXProcedure proc, OpenAIRequest request, boolean isEmbedding, HttpClient client, String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ChatResult chatResult) { try { String jsonRequest = new ObjectMapper().writeValueAsString(request); logger.debug("Agent payload: " + jsonRequest); @@ -44,25 +49,8 @@ public static OpenAIResponse call(OpenAIRequest request, boolean isEmbedding, Ht if (client.getStatusCode() == 200) { String saiaResponse; if (client.getHeader("Content-Type").contains("text/event-stream")){ - saiaResponse = client.readChunk(); - int index = saiaResponse.indexOf("data:") + "data:".length(); - String chunkJson = saiaResponse.substring(index).trim(); - try { - JSONObject jsonResponse = new JSONObject(chunkJson); - OpenAIResponse chunkResponse = new ObjectMapper().readValue(jsonResponse.toString(), OpenAIResponse.class); - OpenAIResponse.Choice choise = chunkResponse.getChoices().get(0); - if (choise.getFinishReason() != null && choise.getFinishReason().equals("tool_calls")){ - saiaResponse = chunkJson; - } - else { - client.unreadChunk(); - return null; - } - } - catch (Exception e) { - client.unreadChunk(); - return null; - } + getChunkedSaiaResponse(proc, client, agent, stream, properties, messages, result, chatResult); + return null; } else { saiaResponse = client.getString(); @@ -88,6 +76,38 @@ public static OpenAIResponse call(OpenAIRequest request, boolean isEmbedding, Ht return null; } + private static void getChunkedSaiaResponse(GXProcedure proc, HttpClient client, String agent, boolean stream, GXProperties properties, ArrayList messages, CallResult result, ChatResult chatResult) { + String saiaChunkResponse = client.readChunk();; + String chunkJson; + while (!client.getEof()) { + logger.debug("Agent response chunk: " + 
saiaChunkResponse); + if (saiaChunkResponse.isEmpty() || saiaChunkResponse.equals("data: [DONE]")) { + saiaChunkResponse = client.readChunk(); + continue; + } + int index = saiaChunkResponse.indexOf("data:") + "data:".length(); + chunkJson = saiaChunkResponse.substring(index).trim(); + try { + JSONObject jsonResponse = new JSONObject(chunkJson); + OpenAIResponse chunkResponse = new ObjectMapper().readValue(jsonResponse.toString(), OpenAIResponse.class); + if (!chunkResponse.getChoices().isEmpty()) { + OpenAIResponse.Choice choice = chunkResponse.getChoices().get(0); + if (choice.getFinishReason() != null && choice.getFinishReason().equals("tool_calls")) { + messages.add(choice.getMessage()); + proc.processNotChunkedResponse(agent, stream, properties, messages, result, chatResult, choice.getMessage().getToolCalls()); + ; + } else if (choice.getDelta() != null && choice.getDelta().getContent() != null) { + chatResult.addChunk(((OpenAIResponse.StringContent) choice.getDelta().getContent()).getValue()); + } + } + saiaChunkResponse = client.readChunk(); + } + catch (Exception e) { + logger.warn("Error deserializing the response chunk", e); + saiaChunkResponse = client.readChunk(); + } + } + } private static void addResultMessage(String id, byte type, String description, CallResult result){ if (type == 1) diff --git a/java/src/test/java/com/genexus/agent/Agent.java b/java/src/test/java/com/genexus/agent/Agent.java index 2098be5fc..fbca98a1e 100644 --- a/java/src/test/java/com/genexus/agent/Agent.java +++ b/java/src/test/java/com/genexus/agent/Agent.java @@ -84,7 +84,7 @@ else if (AV3Parameter1.equals("chat_stream")) { messages.add(message); ChatResult chatResult = chatAgent( "The weatherman", Gxproperties, messages, new CallResult()) ; while (chatResult.hasMoreData()) { - System.out.print(chatResult.hasMoreData()); + System.out.print(chatResult.getMoreData()); } } else if (AV3Parameter1.equals("toolcall")) { @@ -110,7 +110,7 @@ else if 
(AV3Parameter1.equals("toolcall_stream")) { messages.add(message); ChatResult chatResult = chatAgent( "ProductInfo", Gxproperties, messages, new CallResult()) ; while (chatResult.hasMoreData()) { - System.out.print(chatResult.hasMoreData()); + System.out.print(chatResult.getMoreData()); } } else {