| 
17 | 17 | import static org.opensearch.ml.utils.RestActionUtils.isAsync;  | 
18 | 18 | import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;  | 
19 | 19 | 
 
  | 
 | 20 | +import java.io.ByteArrayOutputStream;  | 
20 | 21 | import java.io.IOException;  | 
21 | 22 | import java.io.UncheckedIOException;  | 
22 | 23 | import java.nio.ByteBuffer;  | 
 | 
48 | 49 | import org.opensearch.ml.common.MLModel;  | 
49 | 50 | import org.opensearch.ml.common.agent.MLAgent;  | 
50 | 51 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;  | 
51 |  | -import org.opensearch.ml.common.exception.MLException;  | 
52 | 52 | import org.opensearch.ml.common.input.Input;  | 
53 | 53 | import org.opensearch.ml.common.input.MLInput;  | 
54 | 54 | import org.opensearch.ml.common.input.execute.agent.AgentMLInput;  | 
@@ -158,10 +158,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client  | 
158 | 158 |                 );  | 
159 | 159 |             channel.prepareResponse(RestStatus.OK, headers);  | 
160 | 160 | 
 
  | 
161 |  | -            Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> {  | 
162 |  | -                final CompletableFuture<HttpChunk> future = new CompletableFuture<>();  | 
 | 161 | +            Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> {  | 
163 | 162 |                 try {  | 
164 |  | -                    MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content());  | 
 | 163 | +                    BytesReference completeContent = combineChunks(chunks);  | 
 | 164 | +                    MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);  | 
 | 165 | + | 
 | 166 | +                    final CompletableFuture<HttpChunk> future = new CompletableFuture<>();  | 
165 | 167 |                     StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() {  | 
166 | 168 |                         @Override  | 
167 | 169 |                         public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {  | 
@@ -214,11 +216,10 @@ public MLTaskResponse read(StreamInput in) throws IOException {  | 
214 | 216 |                             handler  | 
215 | 217 |                         );  | 
216 | 218 | 
 
  | 
 | 219 | +                    return Mono.fromCompletionStage(future);  | 
217 | 220 |                 } catch (IOException e) {  | 
218 |  | -                    throw new MLException("Got an exception in flux.", e);  | 
 | 221 | +                    return Mono.error(new OpenSearchStatusException("Failed to parse request", RestStatus.BAD_REQUEST, e));  | 
219 | 222 |                 }  | 
220 |  | - | 
221 |  | -                return Mono.fromCompletionStage(future);  | 
222 | 223 |             }).doOnNext(channel::sendChunk).onErrorComplete(ex -> {  | 
223 | 224 |                 // Error handling  | 
224 | 225 |                 try {  | 
@@ -402,6 +403,19 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) {  | 
402 | 403 |         return Map.of();  | 
403 | 404 |     }  | 
404 | 405 | 
 
  | 
 | 406 | +    private BytesReference combineChunks(List<HttpChunk> chunks) {  | 
 | 407 | +        try {  | 
 | 408 | +            ByteArrayOutputStream buffer = new ByteArrayOutputStream();  | 
 | 409 | +            for (HttpChunk chunk : chunks) {  | 
 | 410 | +                chunk.content().writeTo(buffer);  | 
 | 411 | +            }  | 
 | 412 | +            return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray()));  | 
 | 413 | +        } catch (IOException e) {  | 
 | 414 | +            log.error("Failed to combine chunks", e);  | 
 | 415 | +            throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e);  | 
 | 416 | +        }  | 
 | 417 | +    }  | 
 | 418 | + | 
405 | 419 |     private HttpChunk createHttpChunk(String sseData, boolean isLast) {  | 
406 | 420 |         BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes()));  | 
407 | 421 |         return new HttpChunk() {  | 
 | 
0 commit comments