Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -48,7 +49,6 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
Expand Down Expand Up @@ -158,10 +158,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
);
channel.prepareResponse(RestStatus.OK, headers);

Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> {
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this possible to add some unit tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not possible to create a unit test for prepareRequest because the chunk collection happens asynchronously. However, I did run a sanity test with both chunked and non-chunked requests; see my latest comment.

try {
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content());
BytesReference completeContent = combineChunks(chunks);
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);

final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() {
@Override
public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {
Expand Down Expand Up @@ -214,11 +216,10 @@ public MLTaskResponse read(StreamInput in) throws IOException {
handler
);

return Mono.fromCompletionStage(future);
} catch (IOException e) {
throw new MLException("Got an exception in flux.", e);
return Mono.error(new OpenSearchStatusException("Failed to parse request", RestStatus.BAD_REQUEST, e));
}

return Mono.fromCompletionStage(future);
}).doOnNext(channel::sendChunk).onErrorComplete(ex -> {
// Error handling
try {
Expand Down Expand Up @@ -402,6 +403,20 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) {
return Map.of();
}

@VisibleForTesting
BytesReference combineChunks(List<HttpChunk> chunks) {
    // Concatenate the payloads of all buffered HTTP chunks into a single
    // contiguous buffer so the complete request body can be parsed at once.
    ByteArrayOutputStream assembled = new ByteArrayOutputStream();
    try {
        for (HttpChunk part : chunks) {
            part.content().writeTo(assembled);
        }
    } catch (IOException e) {
        // Surface as a 500: a failure here means we could not reassemble the body.
        log.error("Failed to combine chunks", e);
        throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e);
    }
    return BytesReference.fromByteBuffer(ByteBuffer.wrap(assembled.toByteArray()));
}

private HttpChunk createHttpChunk(String sseData, boolean isLast) {
BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes()));
return new HttpChunk() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.http.HttpChunk;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.agent.LLMSpec;
Expand Down Expand Up @@ -302,4 +304,71 @@ public void testGetRequestAgentFrameworkDisabled() {
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client));
}

@Test
public void testCombineChunksWithSingleChunk() {
    // A single chunk must pass through combineChunks unchanged.
    String payload = "{\"parameters\":{\"question\":\"test\"}}";
    HttpChunk chunk = mock(HttpChunk.class);
    when(chunk.content()).thenReturn(new BytesArray(payload));

    BytesReference combined = restAction.combineChunks(List.of(chunk));

    assertNotNull(combined);
    assertEquals(payload, combined.utf8ToString());
}

@Test
public void testCombineChunksWithMultipleChunks() {
    // Chunks must be concatenated in list order to rebuild the original body.
    String partA = "{\"parameters\":";
    String partB = "{\"question\":";
    String partC = "\"test\"}}";

    HttpChunk chunkA = mock(HttpChunk.class);
    when(chunkA.content()).thenReturn(new BytesArray(partA));
    HttpChunk chunkB = mock(HttpChunk.class);
    when(chunkB.content()).thenReturn(new BytesArray(partB));
    HttpChunk chunkC = mock(HttpChunk.class);
    when(chunkC.content()).thenReturn(new BytesArray(partC));

    BytesReference combined = restAction.combineChunks(List.of(chunkA, chunkB, chunkC));

    assertNotNull(combined);
    assertEquals(partA + partB + partC, combined.utf8ToString());
}

@Test
public void testCombineChunksWithEmptyList() {
    // An empty chunk list should yield an empty (zero-length) reference, never null.
    BytesReference combined = restAction.combineChunks(List.of());

    assertNotNull(combined);
    assertEquals(0, combined.length());
}

@Test
public void testCombineChunksWithLargeContent() {
    // A sizeable payload must survive combination byte-for-byte.
    StringBuilder payloadBuilder = new StringBuilder();
    for (int i = 0; i < 1000; i++) {
        payloadBuilder.append("chunk").append(i).append(",");
    }
    String payload = payloadBuilder.toString();

    HttpChunk chunk = mock(HttpChunk.class);
    when(chunk.content()).thenReturn(new BytesArray(payload));

    BytesReference combined = restAction.combineChunks(List.of(chunk));

    assertNotNull(combined);
    assertEquals(payload.length(), combined.length());
    assertEquals(payload, combined.utf8ToString());
}
}
Loading