Skip to content

Commit 0a7734d

Browse files
jiapingzengylwu-amzn
authored andcommitted
combine json chunks
Signed-off-by: Jiaping Zeng <jpz@amazon.com>
1 parent f5510c9 commit 0a7734d

File tree

1 file changed

+21
-7
lines changed

1 file changed

+21
-7
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
1818
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
1919

20+
import java.io.ByteArrayOutputStream;
2021
import java.io.IOException;
2122
import java.io.UncheckedIOException;
2223
import java.nio.ByteBuffer;
@@ -48,7 +49,6 @@
4849
import org.opensearch.ml.common.MLModel;
4950
import org.opensearch.ml.common.agent.MLAgent;
5051
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
51-
import org.opensearch.ml.common.exception.MLException;
5252
import org.opensearch.ml.common.input.Input;
5353
import org.opensearch.ml.common.input.MLInput;
5454
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -158,10 +158,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
158158
);
159159
channel.prepareResponse(RestStatus.OK, headers);
160160

161-
Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> {
162-
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
161+
Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> {
163162
try {
164-
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content());
163+
BytesReference completeContent = combineChunks(chunks);
164+
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);
165+
166+
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
165167
StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() {
166168
@Override
167169
public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {
@@ -214,11 +216,10 @@ public MLTaskResponse read(StreamInput in) throws IOException {
214216
handler
215217
);
216218

219+
return Mono.fromCompletionStage(future);
217220
} catch (IOException e) {
218-
throw new MLException("Got an exception in flux.", e);
221+
return Mono.error(new OpenSearchStatusException("Failed to parse request", RestStatus.BAD_REQUEST, e));
219222
}
220-
221-
return Mono.fromCompletionStage(future);
222223
}).doOnNext(channel::sendChunk).onErrorComplete(ex -> {
223224
// Error handling
224225
try {
@@ -402,6 +403,19 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) {
402403
return Map.of();
403404
}
404405

406+
private BytesReference combineChunks(List<HttpChunk> chunks) {
407+
try {
408+
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
409+
for (HttpChunk chunk : chunks) {
410+
chunk.content().writeTo(buffer);
411+
}
412+
return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray()));
413+
} catch (IOException e) {
414+
log.error("Failed to combine chunks", e);
415+
throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e);
416+
}
417+
}
418+
405419
private HttpChunk createHttpChunk(String sseData, boolean isLast) {
406420
BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes()));
407421
return new HttpChunk() {

0 commit comments

Comments
 (0)