Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE;

import java.security.PrivilegedActionException;
import java.util.ArrayList;
Expand Down Expand Up @@ -125,9 +126,11 @@ public class MLChatAgentRunner implements MLAgentRunner {
public static final String INJECT_DATETIME_FIELD = "inject_datetime";
public static final String DATETIME_FORMAT_FIELD = "datetime_format";
public static final String SYSTEM_PROMPT_FIELD = "system_prompt";
// Request-parameter key: when set to "true", the agent asks the LLM to summarize the
// steps taken once the iteration limit is reached, instead of only echoing the last thought.
public static final String SUMMARIZE_WHEN_MAX_ITERATION = "summarize_when_max_iteration";

private static final String DEFAULT_MAX_ITERATIONS = "10";
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
// Variant used when summarization is enabled; the %s placeholder carries the LLM-generated summary.
private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE + ". Here's a summary of the steps taken:\n\n%s";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -322,7 +325,6 @@ private void runReAct(

StringBuilder scratchpadBuilder = new StringBuilder();
List<String> interactions = new CopyOnWriteArrayList<>();

StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());
Expand Down Expand Up @@ -414,7 +416,10 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -513,7 +518,10 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -885,11 +893,70 @@ private void handleMaxIterationsReached(
Map<String, Object> additionalInfo,
AtomicReference<String> lastThought,
int maxIterations,
Map<String, Tool> tools,
Map<String, String> parameters,
LLMSpec llmSpec,
String tenantId
) {
boolean shouldSummarize = Boolean.parseBoolean(parameters.getOrDefault(SUMMARIZE_WHEN_MAX_ITERATION, "false"));

if (shouldSummarize && !traceTensors.isEmpty()) {
generateLLMSummary(traceTensors, llmSpec, tenantId, ActionListener.wrap(summary -> {
String summaryResponse = String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary);
sendTraditionalMaxIterationsResponse(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
summaryResponse,
tools
);
}, e -> {
log.error("Failed to generate LLM summary", e);
listener.onFailure(e);
cleanUpResource(tools);
}));
} else {
String response = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
sendTraditionalMaxIterationsResponse(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
response,
tools
);
}
}

private void sendTraditionalMaxIterationsResponse(
String sessionId,
ActionListener<Object> listener,
String question,
String parentInteractionId,
boolean verbose,
boolean traceDisabled,
List<ModelTensors> traceTensors,
ConversationIndexMemory conversationIndexMemory,
AtomicInteger traceNumber,
Map<String, Object> additionalInfo,
String response,
Map<String, Tool> tools
) {
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
sendFinalAnswer(
sessionId,
listener,
Expand All @@ -901,11 +968,86 @@ private void handleMaxIterationsReached(
conversationIndexMemory,
traceNumber,
additionalInfo,
incompleteResponse
response
);
cleanUpResource(tools);
}

/**
 * Asks the configured LLM to produce a natural-language summary of the agent's
 * intermediate execution steps (used when the max-iteration limit is hit).
 *
 * @param stepsSummary trace tensors collected during the ReAct loop; must be non-empty
 * @param llmSpec      spec of the LLM to call (model id plus optional default parameters)
 * @param tenantId     tenant to attribute the prediction request to
 * @param listener     receives the summary text, or a failure if the call or extraction fails
 */
void generateLLMSummary(List<ModelTensors> stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener<String> listener) {
    if (stepsSummary == null || stepsSummary.isEmpty()) {
        listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty"));
        return;
    }

    try {
        // Flatten each tensor into a plain-text line, preferring the result field and
        // falling back to the "response" entry of the data map when no result is set.
        List<String> stepLines = new ArrayList<>();
        for (ModelTensors tensors : stepsSummary) {
            if (tensors == null || tensors.getMlModelTensors() == null) {
                continue;
            }
            for (ModelTensor tensor : tensors.getMlModelTensors()) {
                if (tensor.getResult() != null) {
                    stepLines.add(tensor.getResult());
                } else if (tensor.getDataAsMap() != null && tensor.getDataAsMap().containsKey("response")) {
                    stepLines.add(String.valueOf(tensor.getDataAsMap().get("response")));
                }
            }
        }

        // Start from the LLM's own default parameters, then install the summary prompt.
        Map<String, String> predictionParams = new HashMap<>();
        if (llmSpec.getParameters() != null) {
            predictionParams.putAll(llmSpec.getParameters());
        }
        predictionParams.put(PROMPT, String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepLines)));
        predictionParams.putIfAbsent(SYSTEM_PROMPT_FIELD, "");

        RemoteInferenceMLInput input = RemoteInferenceMLInput
            .builder()
            .algorithm(FunctionName.REMOTE)
            .inputDataset(RemoteInferenceInputDataSet.builder().parameters(predictionParams).build())
            .build();
        ActionRequest request = new MLPredictionTaskRequest(llmSpec.getModelId(), input, null, tenantId);

        client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
            String summary = extractSummaryFromResponse(response);
            if (summary == null) {
                listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
            } else {
                listener.onResponse(summary);
            }
        }, listener::onFailure));
    } catch (Exception e) {
        // Surface synchronous failures (bad spec, request construction, etc.) to the caller.
        listener.onFailure(e);
    }
}

/**
 * Pulls the summary text out of an LLM prediction response.
 *
 * <p>The model is asked to reply as {@code {"response": "..."}}; if the output parses as
 * such a JSON object, the trimmed "response" value is returned. If the output is not
 * valid JSON, the whole output is treated as the summary. Returns {@code null} when the
 * output is empty, lacks a usable "response" value, or the value is blank/"null".
 *
 * @param response the prediction task response to extract from
 * @return the trimmed summary text, or {@code null} when none is available
 * @throws RuntimeException if converting the model output to a string fails
 */
public String extractSummaryFromResponse(MLTaskResponse response) {
    try {
        String raw = outputToOutputString(response.getOutput());
        if (raw == null || raw.trim().isEmpty()) {
            return null;
        }
        try {
            Map<String, Object> parsed = gson.fromJson(raw, Map.class);
            if (!parsed.containsKey("response")) {
                return null;
            }
            String summary = String.valueOf(parsed.get("response"));
            boolean usable = summary != null && !summary.trim().isEmpty() && !"null".equals(summary);
            return usable ? summary.trim() : null;
        } catch (Exception jsonException) {
            // Not a JSON object (or a null literal): treat the entire output as the summary.
            return raw.trim();
        }
    } catch (Exception e) {
        log.error("Failed to extract summary from response", e);
        throw new RuntimeException("Failed to extract summary from response", e);
    }
}

private void saveMessage(
ConversationIndexMemory memory,
String question,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,7 @@ public class PromptTemplate {
- Avoid making assumptions and relying on implicit knowledge.
- Your response must be self-contained and ready for the planner to use without modification. Never end with a question.
- Break complex searches into simpler queries when appropriate.""";

/**
 * Prompt sent to the LLM to summarize the agent's execution steps when the maximum
 * iteration count is reached without a final answer. The single %s placeholder carries
 * the newline-joined step outputs; the model is instructed to reply as
 * {"response": "..."} so the summary can be extracted from the JSON object.
 */
public static final String SUMMARY_PROMPT_TEMPLATE =
"Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s\n\nPlease respond in the following JSON format:\n{\"response\": \"your summary here\"}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -1118,4 +1118,145 @@ public void testConstructLLMParams_DefaultValues() {
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
}

// Verifies that when SUMMARIZE_WHEN_MAX_ITERATION is "true" and the agent exhausts its
// iterations, a second LLM call is made and its output is folded into the
// MAX_ITERATIONS_SUMMARY_MESSAGE response. The summary tensor uses the result field
// (not dataAsMap) to exercise that extraction path.
@Test
public void testMaxIterationsWithSummaryEnabled() {
    // Create LLM spec with max_iteration = 1 to simplify test
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset and setup fresh mocks
    Mockito.reset(client);
    Mockito.reset(firstTool);
    when(firstTool.getName()).thenReturn(FIRST_TOOL);
    when(firstTool.validate(Mockito.anyMap())).thenReturn(true);
    Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any());

    // First call: LLM response without final_answer to trigger max iterations
    // Second call: Summary LLM response with result field instead of dataAsMap
    Mockito
        .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL)))
        .doAnswer(invocation -> {
            ActionListener<Object> listener = invocation.getArgument(2);
            ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build();
            ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
            ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
            MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build();
            listener.onResponse(mlTaskResponse);
            return null;
        })
        .when(client)
        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "true");

    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response contains summary message
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertTrue(
        response.startsWith("Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps taken:")
    );
    assertTrue(response.contains("Summary: Analysis step was attempted"));
}

// Verifies the legacy behavior: with SUMMARIZE_WHEN_MAX_ITERATION set to "false", hitting
// the iteration limit yields the plain MAX_ITERATIONS_MESSAGE plus the last thought, and
// no second (summary) LLM call is made.
@Test
public void testMaxIterationsWithSummaryDisabled() {
    // Create LLM spec with max_iteration = 1 and summary disabled
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset client mock for this test
    Mockito.reset(client);
    // Mock LLM response that doesn't contain final_answer
    Mockito
        .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL)))
        .when(client)
        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    params.put(MLChatAgentRunner.SUMMARIZE_WHEN_MAX_ITERATION, "false");

    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response contains traditional max iterations message
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response);
}

// Verifies extractSummaryFromResponse pulls the "response" value out of a JSON-shaped
// dataAsMap tensor (the format the summary prompt asks the LLM to produce).
@Test
public void testExtractSummaryFromResponse() {
    MLTaskResponse response = MLTaskResponse
        .builder()
        .output(
            ModelTensorOutput
                .builder()
                .mlModelOutputs(
                    Arrays
                        .asList(
                            ModelTensors
                                .builder()
                                .mlModelTensors(
                                    Arrays
                                        .asList(
                                            ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "Valid summary text")).build()
                                        )
                                )
                                .build()
                        )
                )
                .build()
        )
        .build();

    String result = mlChatAgentRunner.extractSummaryFromResponse(response);
    assertEquals("Valid summary text", result);
}

// Verifies generateLLMSummary rejects a null step list by failing the listener with
// IllegalArgumentException instead of calling the LLM.
@Test
public void testGenerateLLMSummaryWithNullSteps() {
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();
    ActionListener<String> listener = Mockito.mock(ActionListener.class);

    mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener);

    verify(listener).onFailure(any(IllegalArgumentException.class));
}
}