Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,40 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.DefaultLlmImpl;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -82,6 +103,110 @@ public void test_bedrock_embedding_model() throws Exception {
}
}

/**
 * Verifies that the Bedrock branch of chat completion parses the newer
 * Claude V3 "content" -> [{ "text": ... }] response format and surfaces
 * the text as the single answer.
 */
public void testChatCompletionBedrockContentFormat() throws Exception {
    Map<String, Object> response = Map.of("content", List.of(Map.of("text", "Claude V3 response text")));

    Map<String, Object> result = invokeBedrockInference(response);

    assertTrue(result.containsKey("answers"));
    List<?> answers = (List<?>) result.get("answers");
    // A successful parse must yield exactly one answer: the extracted text.
    assertEquals(1, answers.size());
    assertEquals("Claude V3 response text", answers.get(0));
}

/**
 * Injects a (mock) ML client into a {@link DefaultLlmImpl} via reflection so
 * tests can intercept predict(...) calls without a live cluster connection.
 *
 * @param connector the LLM implementation under test
 * @param mlClient  the client instance to inject
 * @throws RuntimeException if no known client field exists on
 *                          {@code DefaultLlmImpl} or the field cannot be set
 */
private static void injectMlClient(DefaultLlmImpl connector, Object mlClient) {
    // Candidate field names in order of likelihood; extend this list if the
    // production field is ever renamed.
    for (String fieldName : new String[] { "mlClient", "client" }) {
        try {
            Field field = DefaultLlmImpl.class.getDeclaredField(fieldName);
            field.setAccessible(true);
            field.set(connector, mlClient);
            return;
        } catch (NoSuchFieldException e) {
            // Field not present under this name; try the next candidate.
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl", e);
        }
    }
    throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl: no field named 'mlClient' or 'client'");
}

/**
 * Drives {@code DefaultLlmImpl.doChatCompletion} against a mocked ML client
 * whose asynchronous {@code predict(...)} immediately answers with
 * {@code mockResponse} wrapped in model tensors, then returns the callback's
 * answers/errors as a map.
 *
 * @param mockResponse the raw dataAsMap the mocked Bedrock model returns
 * @return a map with "answers" and "errors" keys captured from the callback
 * @throws RuntimeException if the callback does not complete within 5 seconds
 */
private Map<String, Object> invokeBedrockInference(Map<String, Object> mockResponse) throws Exception {
    // Create the connector and swap its internal ML client for a mock.
    DefaultLlmImpl connector = new DefaultLlmImpl("model_id", null);
    MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
    injectMlClient(connector, mlClient);

    // Wrap mockResponse inside ModelTensor -> ModelTensors -> ModelTensorOutput
    // so it looks like a real ML prediction result.
    ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, mockResponse);
    ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));

    // Drive the async listener directly (do not depend on the ActionFuture
    // return path): predict(...) immediately succeeds with mlOutput.
    doAnswer(invocation -> {
        ActionListener<MLOutput> listener = invocation.getArgument(2);
        listener.onResponse(mlOutput);
        return null;
    }).when(mlClient).predict(any(), any(), any());

    // Use the BEDROCK provider so the bedrock parsing branch is exercised.
    ChatCompletionInput input = new ChatCompletionInput(
        "bedrock/model",
        "question",
        Collections.emptyList(),
        Collections.emptyList(),
        0,
        "prompt",
        "instructions",
        Llm.ModelProvider.BEDROCK,
        null,
        null
    );

    // Block until the asynchronous callback fires, capturing its result.
    CountDownLatch latch = new CountDownLatch(1);
    AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();

    connector.doChatCompletion(input, new ActionListener<>() {
        @Override
        public void onResponse(ChatCompletionOutput output) {
            Map<String, Object> map = new HashMap<>();
            map.put("answers", output.getAnswers());
            map.put("errors", output.getErrors());
            resultRef.set(map);
            latch.countDown();
        }

        @Override
        public void onFailure(Exception e) {
            Map<String, Object> map = new HashMap<>();
            map.put("answers", Collections.emptyList());
            // String.valueOf guards against a null exception message, which
            // would make List.of(...) throw NPE and hang the latch.
            map.put("errors", List.of(String.valueOf(e.getMessage())));
            resultRef.set(map);
            latch.countDown();
        }
    });

    if (!latch.await(5, TimeUnit.SECONDS)) {
        throw new RuntimeException("Timed out waiting for doChatCompletion callback");
    }
    return resultRef.get();
}

/**
 * Asserts that the output map carries an "error" entry whose message matches
 * the expected text, accepting either a nested map ({@code {"message": ...}})
 * or a plain string error representation.
 */
private void validateErrorOutput(String msg, Map<String, Object> output, String expectedError) {
    assertTrue(msg, output.containsKey("error"));
    Object error = output.get("error");

    if (error instanceof Map<?, ?> errorMap) {
        assertEquals(msg, expectedError, errorMap.get("message"));
    } else if (error instanceof String errorText) {
        assertEquals(msg, expectedError, errorText);
    } else {
        fail("Unexpected error format: " + error.getClass());
    }
}

private void validateOutput(String errorMsg, Map<String, Object> output) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,26 +470,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " }\n"
+ "}";

private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n"
+ " \"_source\": [\"%s\"],\n"
+ " \"query\" : {\n"
+ " \"match\": {\"%s\": \"%s\"}\n"
+ " },\n"
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"memory_id\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d,\n"
+ " \"llm_response_field\": \"%s\"\n"
+ " }\n"
+ " }\n"
+ "}";

private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n"
+ " \"_source\": [\"%s\"],\n"
+ " \"query\" : {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public class DefaultLlmImpl implements Llm {
private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
private static final String CONNECTOR_OUTPUT_ERROR = "error";
// Legacy Bedrock responses carry the answer directly under "completion".
private static final String BEDROCK_COMPLETION_FIELD = "completion";
// Newer Bedrock responses nest the answer as "content" -> [ { "text": ... } ].
private static final String BEDROCK_CONTENT_FIELD = "content";
private static final String BEDROCK_TEXT_FIELD = "text";

private final String openSearchModelId;

Expand Down Expand Up @@ -191,8 +194,38 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
}
} else if (provider == ModelProvider.BEDROCK) {
answerField = "completion";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
// Handle Bedrock model responses (supports both legacy completion and newer content/text formats)

Object contentObj = dataAsMap.get(BEDROCK_CONTENT_FIELD);
if (contentObj == null) {
// Legacy completion-style format
Object completion = dataAsMap.get(BEDROCK_COMPLETION_FIELD);
if (completion != null) {
answers.add(completion.toString());
} else {
errors.add("Unsupported Bedrock response format: " + dataAsMap.keySet());
log.error("Unknown Bedrock response format: {}", dataAsMap);
}
} else {
// Fail-fast checks for new content/text format
if (!(contentObj instanceof List<?> contentList)) {
errors.add("Unexpected type for '" + BEDROCK_CONTENT_FIELD + "' in Bedrock response.");
} else if (contentList.isEmpty()) {
errors.add("Empty content list in Bedrock response.");
} else {
Object first = contentList.get(0);
if (!(first instanceof Map<?, ?> firstMap)) {
errors.add("Unexpected content format in Bedrock response.");
} else {
Object text = firstMap.get(BEDROCK_TEXT_FIELD);
if (text == null) {
errors.add("Bedrock content response missing '" + BEDROCK_TEXT_FIELD + "' field.");
} else {
answers.add(text.toString());
}
}
}
}
} else if (provider == ModelProvider.COHERE) {
answerField = "text";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
Expand Down
Loading