diff --git a/pom.xml b/pom.xml index 2b97718..8a79f43 100644 --- a/pom.xml +++ b/pom.xml @@ -57,7 +57,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 2.2.0 + 2.3.0 diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 7ebe3b8..7e8b1d6 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -9,6 +9,10 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import java.io.IOException; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -47,9 +51,12 @@ import software.amazon.awssdk.services.sqs.model.SendMessageRequest; import software.amazon.awssdk.services.sqs.model.SendMessageResponse; import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.IoUtils; import software.amazon.payloadoffloading.PayloadStoreAsync; import software.amazon.payloadoffloading.S3AsyncDao; import software.amazon.payloadoffloading.S3BackedPayloadStoreAsync; +import software.amazon.payloadoffloading.S3BackedStreamPayloadStoreAsync; +import software.amazon.payloadoffloading.StreamPayloadStoreAsync; import software.amazon.payloadoffloading.Util; /** @@ -129,7 +136,12 @@ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient, S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); - this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + + if (clientConfiguration.isStreamUploadEnabled()) { + this.payloadStore = new S3BackedStreamPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + } else { + this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + } } /** @@ -284,6 +296,154 @@ public CompletableFuture receiveMessage(ReceiveMessageRe }); } + /** + * Receives messages from the specified queue with streaming access to large payloads. 
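+     * <p>
+     * Illustrative usage sketch (the {@code extendedClient} and {@code request} variables are
+     * caller-supplied placeholders); when the payload was offloaded to S3 and a stream-capable
+     * payload store is configured, the payload is exposed as an open stream rather than being
+     * buffered into the message body:
+     * </p>
+     * <pre>{@code
+     * extendedClient.receiveMessageAsStream(request).thenAccept(response -> {
+     *     for (StreamMessage streamMessage : response.streamMessages()) {
+     *         if (streamMessage.hasStreamPayload()) {
+     *             try (InputStream payload = streamMessage.getPayloadStream()) {
+     *                 // consume the payload incrementally instead of loading it into memory
+     *             } catch (IOException e) {
+     *                 throw new UncheckedIOException(e);
+     *             }
+     *         } else {
+     *             // small payloads are still delivered inline
+     *             String body = streamMessage.getMessage().body();
+     *         }
+     *     }
+     * });
+     * }</pre>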
+ * + * @param receiveMessageRequest The receive message request + * @return A CompletableFuture containing the response with streaming message access + */ + public CompletableFuture receiveMessageAsStream(ReceiveMessageRequest receiveMessageRequest) { + if (receiveMessageRequest == null) { + String errorMessage = "receiveMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); + appendUserAgent(receiveMessageRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // If payload support is disabled, fall back to regular receive and wrap in stream response + return super.receiveMessage(receiveMessageRequestBuilder.build()) + .thenApply(receiveMessageResponse -> { + List streamMessages = receiveMessageResponse.messages().stream() + .map(message -> new StreamMessage(message, null)) + .collect(Collectors.toList()); + return new ReceiveStreamMessageResponse(streamMessages); + }); + } + + // Remove before adding to avoid any duplicates + List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); + String queueUrl = receiveMessageRequest.queueUrl(); + receiveMessageRequest = receiveMessageRequestBuilder.build(); + + return super.receiveMessage(receiveMessageRequest) + .thenCompose(receiveMessageResponse -> { + List messages = receiveMessageResponse.messages(); + + // Check for no messages. If so, no need to process further. + if (messages.isEmpty()) { + return CompletableFuture.completedFuture(new ArrayList()); + } + + List> streamMessageFutures = new ArrayList<>(messages.size()); + for (Message message : messages) { + // For each received message check if they are stored in S3. 
+ Optional largePayloadAttributeName = getReservedAttributeNameIfPresent( + message.messageAttributes()); + if (!largePayloadAttributeName.isPresent()) { + // Not S3 - create StreamMessage without stream + streamMessageFutures.add(CompletableFuture.completedFuture(new StreamMessage(message, null))); + } else { + // In S3 - get streaming access to payload + final String largeMessagePointer = message.body() + .replace("com.amazon.sqs.javamessaging.MessageS3Pointer", + "software.amazon.payloadoffloading.PayloadS3Pointer"); + + // Check if we have stream payload store for streaming + if (payloadStore instanceof StreamPayloadStoreAsync) { + StreamPayloadStoreAsync streamStore = (StreamPayloadStoreAsync) payloadStore; + + streamMessageFutures.add(streamStore.getOriginalPayloadStreamStream(largeMessagePointer) + .handle((stream, throwable) -> { + if (throwable != null) { + if (clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(queueUrl) + .receiptHandle(message.receiptHandle()) + .build(); + + deleteMessage(deleteMessageRequest).join(); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + return null; + } else { + throw new CompletionException(throwable); + } + } + + Message.Builder messageBuilder = message.toBuilder(); + // Remove the additional attribute before returning the message + Map messageAttributes = new HashMap<>( + message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. + String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + messageBuilder.receiptHandle(modifiedReceiptHandle); + + return new StreamMessage(messageBuilder.build(), stream); + })); + } else { + // Fall back to regular payload retrieval if no streaming support + streamMessageFutures.add(payloadStore.getOriginalPayload(largeMessagePointer) + .handle((originalPayload, throwable) -> { + if (throwable != null) { + if (clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(queueUrl) + .receiptHandle(message.receiptHandle()) + .build(); + + deleteMessage(deleteMessageRequest).join(); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + return null; + } else { + throw new CompletionException(throwable); + } + } + + // Set original payload and create StreamMessage without stream + Message.Builder messageBuilder = message.toBuilder(); + messageBuilder.body(originalPayload); + + // Remove the additional attribute before returning the message + Map messageAttributes = new HashMap<>( + message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. + String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + messageBuilder.receiptHandle(modifiedReceiptHandle); + + return new StreamMessage(messageBuilder.build(), null); + })); + } + } + } + + // Convert list of stream message futures to a future list of stream messages. 
+ return CompletableFuture.allOf( + streamMessageFutures.toArray(new CompletableFuture[streamMessageFutures.size()])) + .thenApply(v -> streamMessageFutures.stream() + .map(CompletableFuture::join) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + }) + .thenApply(streamMessages -> new ReceiveStreamMessageResponse(streamMessages)); + } + /** * {@inheritDoc} */ @@ -443,6 +603,69 @@ public CompletableFuture deleteMessageBatch( return super.deleteMessageBatch(deleteMessageBatchRequestBuilder.build()); } + /** + *
<p>
+ * Delivers a message to the specified queue using an InputStream for the message body. + * This method allows sending large messages without loading them entirely into memory. + *
</p>
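+     * <p>
+     * If payload support is disabled, or the content length is below the configured payload size
+     * threshold and the message is not forced through S3, the stream is read fully into memory and
+     * sent as a regular message body.
+     * </p>
+     * <p>
+     * Illustrative usage sketch (queue URL, file path and the {@code extendedAsyncClient} variable
+     * are placeholders):
+     * </p>
+     * <pre>{@code
+     * SendMessageRequest request = SendMessageRequest.builder()
+     *         .queueUrl("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue")
+     *         .build();
+     * Path payload = Paths.get("large-payload.json");
+     * try (InputStream in = Files.newInputStream(payload)) {
+     *     // keep the stream open until the asynchronous send has completed
+     *     extendedAsyncClient.sendStreamMessage(request, in, Files.size(payload)).join();
+     * }
+     * }</pre>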
+ * + * @param sendMessageRequest The send message request with message body as InputStream + * @param messageBodyStream InputStream containing the message body content + * @param contentLength The total length of the content in the stream + * @return CompletableFuture of the SendMessage operation returned by the service. + */ + public CompletableFuture sendStreamMessage(SendMessageRequest sendMessageRequest, + InputStream messageBodyStream, + long contentLength) { + if (sendMessageRequest == null) { + String errorMessage = "sendMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert stream to string for non-extended client + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + SendMessageRequest finalRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(finalRequest); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; + } + } + + if (messageBodyStream == null) { + String errorMessage = "messageBodyStream cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + return storeStreamMessageInS3(sendMessageRequest, messageBodyStream, contentLength) + .thenCompose(modifiedRequest -> super.sendMessage(modifiedRequest)); + } else { + // Convert stream to string for small messages + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + SendMessageRequest finalRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(finalRequest); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; + } + } + } + /** * {@inheritDoc} */ @@ -530,6 +753,58 @@ private CompletableFuture storeOriginalPayload(String messageContentStr) return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); } + + private CompletableFuture storeStreamMessageInS3(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + + sendMessageRequestBuilder.messageAttributes( + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. 
+ return storeOriginalPayload(messageBodyStream) + .thenApply(largeMessagePointer -> { + sendMessageRequestBuilder.messageBody(largeMessagePointer); + return sendMessageRequestBuilder.build(); + }); + } + + private CompletableFuture storeOriginalPayload(InputStream messageContentStream) { + String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); + String key = StringUtils.isBlank(s3KeyPrefix) ? UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); + + if (payloadStore instanceof StreamPayloadStoreAsync) { + return ((StreamPayloadStoreAsync) payloadStore) + .storeOriginalPayloadStream(messageContentStream, key) + .exceptionally(ex -> { + LOG.warn("Stream upload attempt failed; falling back to standard single-part upload."); + try { + messageContentStream.reset(); + } catch (Exception resetEx) { + LOG.warn("Failed to reset stream after multipart upload failure, stream may be exhausted", resetEx); + throw new RuntimeException("Stream upload failed and stream cannot be reset", ex); + } + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key).join(); + } catch (IOException ioEx) { + throw new RuntimeException("Failed to read from InputStream", ioEx); + } + }); + } + + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; + } + } + private static T appendUserAgent(final T builder) { return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION); } diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 5b372a9..9a6ed07 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -24,6 +24,7 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import java.net.URI; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,8 +34,11 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.util.VersionInfo; @@ -75,11 +79,17 @@ import software.amazon.awssdk.services.sqs.model.SqsException; import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException; import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.IoUtils; import software.amazon.payloadoffloading.PayloadStore; +import software.amazon.payloadoffloading.StreamPayloadStore; import software.amazon.payloadoffloading.S3BackedPayloadStore; +import 
software.amazon.payloadoffloading.S3BackedStreamPayloadStore; import software.amazon.payloadoffloading.S3Dao; import software.amazon.payloadoffloading.Util; +import java.io.InputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; /** * Amazon SQS Extended Client extends the functionality of Amazon SQS client. @@ -142,10 +152,19 @@ public AmazonSQSExtendedClient(SqsClient sqsClient) { public AmazonSQSExtendedClient(SqsClient sqsClient, ExtendedClientConfiguration extendedClientConfig) { super(sqsClient); this.clientConfiguration = new ExtendedClientConfiguration(extendedClientConfig); - S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), - clientConfiguration.getServerSideEncryptionStrategy(), - clientConfiguration.getObjectCannedACL()); - this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); + + if (clientConfiguration.isStreamUploadEnabled()) { + S3Dao s3Dao = new S3Dao( + clientConfiguration.getS3Client(), + clientConfiguration.getServerSideEncryptionStrategy(), + clientConfiguration.getObjectCannedACL(), + clientConfiguration.getStreamUploadPartSize(), + clientConfiguration.getStreamUploadThreshold()); + this.payloadStore = new S3BackedStreamPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); + } else { + S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); + this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); + } } /** @@ -376,6 +395,88 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag return receiveMessageResponseBuilder.build(); } + public ReceiveStreamMessageResponse receiveMessageAsStream(ReceiveMessageRequest receiveMessageRequest) { + //TODO: Clone request since it's modified in this method and will cause issues if the client reuses request object. 
+ if (receiveMessageRequest == null) { + String errorMessage = "receiveMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); + appendUserAgent(receiveMessageRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // If payload support is disabled, fall back to regular receive and wrap in StreamMessage with null streams + ReceiveMessageResponse response = super.receiveMessage(receiveMessageRequestBuilder.build()); + List streamMessages = response.messages().stream() + .map(message -> new StreamMessage(message, null)) + .collect(java.util.stream.Collectors.toList()); + return new ReceiveStreamMessageResponse(streamMessages); + } + + //Remove before adding to avoid any duplicates + List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); + receiveMessageRequest = receiveMessageRequestBuilder.build(); + + ReceiveMessageResponse receiveMessageResponse = super.receiveMessage(receiveMessageRequest); + List streamMessages = new ArrayList<>(receiveMessageResponse.messages().size()); + + for (Message message : receiveMessageResponse.messages()) { + Message.Builder messageBuilder = message.toBuilder(); + ResponseInputStream payloadStream = null; + + // for each received message check if they are stored in S3. + Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(message.messageAttributes()); + if (largePayloadAttributeName.isPresent()) { + String largeMessagePointer = message.body(); + largeMessagePointer = largeMessagePointer.replace("com.amazon.sqs.javamessaging.MessageS3Pointer", "software.amazon.payloadoffloading.PayloadS3Pointer"); + + try { + if (payloadStore instanceof StreamPayloadStore) { + payloadStream = ((StreamPayloadStore) payloadStore).getOriginalPayloadStreamStream(largeMessagePointer); + } else { + // Fallback: load into memory and create a stream from it + System.out.println("Warning: payload store is not a StreamPayloadStore, loading entire payload into memory"); + String payload = payloadStore.getOriginalPayload(largeMessagePointer); + ByteArrayInputStream byteStream = new ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + payloadStream = new ResponseInputStream<>(GetObjectResponse.builder().build(), byteStream); + } + } catch (SdkException e) { + if (e.getCause() instanceof NoSuchKeyException && clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(receiveMessageRequest.queueUrl()) + .receiptHandle(message.receiptHandle()) + .build(); + deleteMessage(deleteMessageRequest); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + continue; + } else throw e; + } + + // remove the additional attribute before returning the message to user + Map messageAttributes = new HashMap<>(message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. 
+ String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + + messageBuilder.receiptHandle(modifiedReceiptHandle); + } + + streamMessages.add(new StreamMessage(messageBuilder.build(), payloadStream)); + } + + return new ReceiveStreamMessageResponse(streamMessages); + } + /** *
<p>
* Deletes the specified message from the specified queue. To select the message to delete, use the @@ -652,6 +753,150 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes return super.sendMessageBatch(sendMessageBatchRequest); } + /** + *
<p>
+ * Delivers a message to the specified queue using an InputStream for the message body. + * This method allows sending large messages without loading them entirely into memory. + *
</p>
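+     * <p>
+     * Content at or above the configured payload size threshold (or any content when the client is
+     * configured to always go through S3) is offloaded to S3; smaller content is read into memory
+     * and sent inline.
+     * </p>
+     * <p>
+     * Illustrative usage sketch (queue URL, file path and the {@code extendedClient} variable are
+     * placeholders):
+     * </p>
+     * <pre>{@code
+     * SendMessageRequest request = SendMessageRequest.builder()
+     *         .queueUrl("https://sqs.us-east-1.amazonaws.com/123456789012/my-queue")
+     *         .build();
+     * Path report = Paths.get("large-report.json");
+     * try (InputStream in = Files.newInputStream(report)) {
+     *     SendMessageResponse result = extendedClient.sendStreamMessage(request, in, Files.size(report));
+     * }
+     * }</pre>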
+ * + * @param sendMessageRequest The send message request with message body as InputStream + * @param messageBodyStream InputStream containing the message body content + * @param contentLength The total length of the content in the stream + * @return Result of the SendMessage operation returned by the service. + */ + public SendMessageResponse sendStreamMessage(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + if (sendMessageRequest == null) { + String errorMessage = "sendMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert stream to string for non-extended client + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(sendMessageRequest); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + + if (messageBodyStream == null) { + String errorMessage = "messageBodyStream cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + sendMessageRequest = storeStreamMessageInS3(sendMessageRequest, messageBodyStream, contentLength); + } else { + // Convert stream to string for small messages + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + return super.sendMessage(sendMessageRequest); + } + + /** + *
<p>
+ * Delivers up to ten messages to the specified queue using InputStreams for message bodies. + * This method allows sending large messages without loading them entirely into memory. + *
</p>
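+     * <p>
+     * The streams and content lengths are matched to batch entries by index and must have the same
+     * size as the entry list. Entries whose content length meets the payload size threshold (or all
+     * entries when the client is configured to always go through S3) are offloaded to S3; smaller
+     * entries are read into memory and sent inline.
+     * </p>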
+ * + * @param sendMessageBatchRequest The send message batch request + * @param messageBodyStreams List of InputStreams containing message body content for each entry + * @param contentLengths List of content lengths corresponding to each stream + * @return Result of the SendMessageBatch operation returned by the service. + */ + public SendMessageBatchResponse sendStreamMessageBatch(SendMessageBatchRequest sendMessageBatchRequest, + List messageBodyStreams, + List contentLengths) { + + if (sendMessageBatchRequest == null) { + String errorMessage = "sendMessageBatchRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (messageBodyStreams == null || contentLengths == null) { + String errorMessage = "messageBodyStreams and contentLengths cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (messageBodyStreams.size() != sendMessageBatchRequest.entries().size() || + contentLengths.size() != sendMessageBatchRequest.entries().size()) { + String errorMessage = "messageBodyStreams and contentLengths must have the same size as batch entries."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageBatchRequest.Builder sendMessageBatchRequestBuilder = sendMessageBatchRequest.toBuilder(); + appendUserAgent(sendMessageBatchRequestBuilder); + sendMessageBatchRequest = sendMessageBatchRequestBuilder.build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert streams to strings for non-extended client + List entries = new ArrayList<>(); + for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) { + SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i); + try { + String messageBody = IoUtils.toUtf8String(messageBodyStreams.get(i)); + entries.add(entry.toBuilder().messageBody(messageBody).build()); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(entries).build(); + return super.sendMessageBatch(sendMessageBatchRequest); + } + + List batchEntries = new ArrayList<>(sendMessageBatchRequest.entries().size()); + + boolean hasS3Entries = false; + for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) { + SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i); + InputStream stream = messageBodyStreams.get(i); + long contentLength = contentLengths.get(i); + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + entry = storeStreamMessageInS3(entry, stream, contentLength); + hasS3Entries = true; + } else { + // Convert stream to string for small messages + try { + String messageBody = IoUtils.toUtf8String(stream); + entry = entry.toBuilder().messageBody(messageBody).build(); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + batchEntries.add(entry); + } + + if (hasS3Entries) { + sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(batchEntries).build(); + } + + return super.sendMessageBatch(sendMessageBatchRequest); + } + /** *
<p>
* Deletes up to ten messages from the specified queue. This is a batch version of @@ -900,6 +1145,68 @@ private String storeOriginalPayload(String messageContentStr) { } return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); } + + private SendMessageRequest storeStreamMessageInS3(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + + sendMessageRequestBuilder.messageAttributes( + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + String largeMessagePointer = storeOriginalPayload(messageBodyStream); + sendMessageRequestBuilder.messageBody(largeMessagePointer); + + return sendMessageRequestBuilder.build(); + } + + private SendMessageBatchRequestEntry storeStreamMessageInS3(SendMessageBatchRequestEntry batchEntry, InputStream messageBodyStream, long contentLength) { + SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder(); + + batchEntryBuilder.messageAttributes( + updateMessageAttributePayloadSize(batchEntry.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + String largeMessagePointer = storeOriginalPayload(messageBodyStream); + batchEntryBuilder.messageBody(largeMessagePointer); + + return batchEntryBuilder.build(); + } + + private String storeOriginalPayload(InputStream messageContentStream) { + String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); + String key = StringUtils.isBlank(s3KeyPrefix) ? UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); + + if (payloadStore instanceof StreamPayloadStore) { + try { + return ((StreamPayloadStore) payloadStore).storeOriginalPayloadStream(messageContentStream, key); + } catch (RuntimeException e) { + LOG.warn("Stream upload attempt failed; falling back to standard single-part upload."); + try { + messageContentStream.reset(); + } catch (Exception resetEx) { + LOG.warn("Failed to reset stream after multipart upload failure, stream may be exhausted", resetEx); + throw e; + } + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException ioEx) { + throw new RuntimeException("Failed to read from InputStream", ioEx); + } + } + } + + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } @SuppressWarnings("unchecked") private static T appendUserAgent(final T builder) { diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index e96fff2..281bf3b 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -212,4 +212,10 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry this.setServerSideEncryptionStrategy(serverSideEncryption); return this; } + + @Override + public 
ExtendedAsyncClientConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); + return this; + } } \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 75a30f8..9e0d41e 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -232,6 +232,26 @@ public ExtendedClientConfiguration withServerSideEncryption(ServerSideEncryption return this; } + /** + * Enables or disables stream upload support for large payload storage operations. + * @param enabled true to enable stream uploads when threshold exceeded. + * @return updated configuration + */ + public ExtendedClientConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); + return this; + } + + public ExtendedClientConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); + return this; + } + + public ExtendedClientConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } + /** * Enables support for large-payload messages. * diff --git a/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java new file mode 100644 index 0000000..c93d458 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java @@ -0,0 +1,21 @@ +package com.amazon.sqs.javamessaging; + +import java.util.List; + +/** + * Response containing messages with streaming access to large payloads. + */ +public class ReceiveStreamMessageResponse { + private final List streamMessages; + + public ReceiveStreamMessageResponse(List streamMessages) { + this.streamMessages = streamMessages; + } + + /** + * @return list of messages with streaming payload access + */ + public List streamMessages() { + return streamMessages; + } +} \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java new file mode 100644 index 0000000..9f08986 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java @@ -0,0 +1,40 @@ +package com.amazon.sqs.javamessaging; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.sqs.model.Message; + +/** + * Represents a message received from SQS that can provide streaming access to large payloads stored in S3. + * Instead of loading the entire payload into memory, this provides access to the payload as a stream. + */ +public class StreamMessage { + private final Message originalMessage; + private final ResponseInputStream payloadStream; + + public StreamMessage(Message originalMessage, ResponseInputStream payloadStream) { + this.originalMessage = originalMessage; + this.payloadStream = payloadStream; + } + + /** + * @return the original SQS message metadata (messageId, receiptHandle, attributes, etc.) 
+ */ + public Message getMessage() { + return originalMessage; + } + + /** + * @return stream for accessing the message payload, or null if payload is not stored in S3 + */ + public ResponseInputStream getPayloadStream() { + return payloadStream; + } + + /** + * @return true if this message has a streaming payload from S3 + */ + public boolean hasStreamPayload() { + return payloadStream != null; + } +} \ No newline at end of file diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java index 219e57d..e3a42e9 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -6,9 +6,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doThrow; @@ -20,6 +23,8 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -38,6 +43,7 @@ import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; @@ -63,10 +69,14 @@ import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; import software.amazon.awssdk.services.sqs.model.SendMessageRequest; import software.amazon.awssdk.services.sqs.model.SendMessageResponse; -import software.amazon.awssdk.utils.ImmutableMap; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.payloadoffloading.StreamPayloadStoreAsync; +import software.amazon.payloadoffloading.PayloadStoreAsync; import software.amazon.payloadoffloading.PayloadS3Pointer; import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; +import software.amazon.awssdk.utils.ImmutableMap; public class AmazonSQSExtendedAsyncClientTest { @@ -88,6 +98,11 @@ public class AmazonSQSExtendedAsyncClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; + + // Stream upload thresholds + private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int LESS_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD - 1; + private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1; @BeforeEach public void setupClients() { @@ -729,4 +744,338 @@ 
private String generateStringWithLength(int messageLength) { Arrays.fill(charArray, 'x'); return new String(charArray); } + + @Test + public void testReceiveMessageAsStream_PayloadSupportDisabled_ReturnsMessagesWithoutStreams() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportDisabled(); + AmazonSQSExtendedAsyncClient clientWithDisabledPayload = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + Message message = Message.builder() + .messageId("msg1") + .body("small message") + .receiptHandle("receipt1") + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture future = clientWithDisabledPayload.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertEquals("small message", streamMessage.getMessage().body()); + assertFalse(streamMessage.hasStreamPayload()); + } + + @Test + public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSizeThreshold(262144); // 256KB + config.setStreamUploadEnabled(true); + + AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + @SuppressWarnings("unchecked") + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + @SuppressWarnings("unchecked") + CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(futureStream); + + CompletableFuture future = clientWithStream.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertTrue(streamMessage.hasStreamPayload()); + assertSame(mockStream, streamMessage.getPayloadStream()); + assertTrue(streamMessage.getMessage().receiptHandle().contains("test-key")); + } + + @Test + public void 
testReceiveMessageAsStream_LargeMessage_WithoutStreamStore_FallsBackToRegularRetrieval() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + + AmazonSQSExtendedAsyncClient clientWithRegularStore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + ResponseBytes s3Object = ResponseBytes.fromByteArray( + GetObjectResponse.builder().build(), + "large payload content".getBytes(StandardCharsets.UTF_8)); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(s3Object)); + + CompletableFuture future = clientWithRegularStore.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertEquals("large payload content", streamMessage.getMessage().body()); + assertFalse(streamMessage.hasStreamPayload()); + } + + @Test + public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundEnabled_DeletesMessage() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withIgnorePayloadNotFound(true); + + AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture> failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(NoSuchKeyException.builder().build()); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(failedFuture); + + when(mockSqsBackend.deleteMessage(any(DeleteMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(DeleteMessageResponse.builder().build())); + + CompletableFuture future = clientWithIgnore.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertTrue(response.streamMessages().isEmpty()); + + ArgumentCaptor deleteCaptor 
= ArgumentCaptor.forClass(DeleteMessageRequest.class); + verify(mockSqsBackend).deleteMessage(deleteCaptor.capture()); + assertEquals("receipt1", deleteCaptor.getValue().receiptHandle()); + } + + @Test + public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundDisabled_ThrowsException() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withIgnorePayloadNotFound(false); + + AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture> failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(NoSuchKeyException.builder().build()); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(failedFuture); + + assertThrows(CompletionException.class, () -> { + clientWithIgnore.receiveMessageAsStream(request).join(); + }); + } + + @Test + public void testReceiveMessageAsStream_MultipleMessages_MixedTypes() throws IOException { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSizeThreshold(262144); // 256KB + config.setStreamUploadEnabled(true); // Enable stream uploads + + AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + // Small message (no S3) + Message smallMessage = Message.builder() + .messageId("msg1") + .body("small message") + .receiptHandle("receipt1") + .build(); + + // Large message (with S3 pointer) + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message largeMessage = Message.builder() + .messageId("msg2") + .body(s3Pointer) + .receiptHandle("receipt2") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(Arrays.asList(smallMessage, largeMessage)) + .build(); + + @SuppressWarnings("unchecked") + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + @SuppressWarnings("unchecked") + CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(futureStream); + + CompletableFuture future = clientWithStream.receiveMessageAsStream(request); + 
ReceiveStreamMessageResponse response = future.join(); + + assertEquals(2, response.streamMessages().size()); + + StreamMessage msg1 = response.streamMessages().get(0); + assertEquals("msg1", msg1.getMessage().messageId()); + assertFalse(msg1.hasStreamPayload()); + + StreamMessage msg2 = response.streamMessages().get(1); + assertEquals("msg2", msg2.getMessage().messageId()); + assertTrue(msg2.hasStreamPayload()); + assertSame(mockStream, msg2.getPayloadStream()); + } + + @Test + public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() { + int fileSizeBytes = 500_000; + String fileContent = generateStringWithLength(fileSizeBytes); + java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(), + "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build() + )) + .build(); + + ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig) + .sendStreamMessage(request, fileStream, fileSizeBytes).join(); + + ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(mockS3, times(1)).putObject(s3Captor.capture(), any(AsyncRequestBody.class)); + assertEquals(S3_BUCKET_NAME, s3Captor.getValue().bucket()); + + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + + assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME)); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName")); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType")); + assertTrue(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + assertEquals(String.valueOf(fileSizeBytes), + sqsCaptor.getValue().messageAttributes() + .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()); + } + + @Test + public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() { + String smallMessage = "Small notification message"; + java.io.InputStream stream = new java.io.ByteArrayInputStream(smallMessage.getBytes(StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build() + )) + .build(); + + ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig) + .sendStreamMessage(request, stream, smallMessage.length()).join(); + + verify(mockS3, never()).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + assertEquals(smallMessage, sqsCaptor.getValue().messageBody()); + + assertFalse(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + } } + diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 96ac44f..858abb6 100644 --- 
a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -68,6 +68,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -82,6 +83,12 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.io.ByteArrayInputStream; + + /** * Tests the AmazonSQSExtendedClient class. */ @@ -113,6 +120,13 @@ public class AmazonSQSExtendedClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; + + // Stream upload thresholds + private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1; + + // Stream part size + private static final int STREAM_UPLOAD_PART_SIZE = 5 * 1024 * 1024; // 5MB @BeforeEach public void setupClients() { @@ -777,4 +791,204 @@ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle) private String getSampleLargeReceiptHandle(String originalReceiptHandle) { return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle); } + + @Test + public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException { + String largeMessageBody = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD); + + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); + when(mockS3.getObject(isA(GetObjectRequest.class))).thenReturn(mockStream); + + String s3Key = "stream-key-" + UUID.randomUUID(); + String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, s3Key).toJson(); + Message sqsMessage = Message.builder() + .messageAttributes(ImmutableMap.of(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, + MessageAttributeValue.builder().dataType("Number").stringValue(String.valueOf(largeMessageBody.length())).build())) + .body(pointer) + .receiptHandle("test-receipt-handle") + .build(); + + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder().messages(sqsMessage).build()); + + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withStreamUploadEnabled(true) + .withStreamUploadPartSize(STREAM_UPLOAD_PART_SIZE) + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); + AmazonSQSExtendedClient sqsExtended = new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration); + + ReceiveMessageRequest request = ReceiveMessageRequest.builder().queueUrl(SQS_QUEUE_URL).build(); + ReceiveStreamMessageResponse response = sqsExtended.receiveMessageAsStream(request); + sqsExtended.close(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + + assertTrue(streamMessage.hasStreamPayload()); + + ResponseInputStream payloadStream = streamMessage.getPayloadStream(); + assertNotNull(payloadStream); + + 
verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class)); + assertEquals("-..s3BucketName..-test-bucket-name-..s3BucketName..--..s3Key..-stream-key-test-s3-key-uuid-..s3Key..-test-receipt-handle", streamMessage.getMessage().receiptHandle()); + } + + @Test + public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() { + int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD; + String fileContent = generateStringWithLength(fileSizeBytes); + InputStream fileStream = new ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8)); + + ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); + + SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(), + "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build() + )) + .build(); + + ((AmazonSQSExtendedClient) streamClient).sendStreamMessage(request, fileStream, fileSizeBytes); + + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + + assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME)); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName")); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType")); + assertTrue(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + assertEquals(String.valueOf(fileSizeBytes), + sqsCaptor.getValue().messageAttributes() + .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()); + } + + @Test + public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() { + String smallMessage = "Small notification message"; + java.io.InputStream stream = new java.io.ByteArrayInputStream(smallMessage.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build() + )) + .build(); + + ((AmazonSQSExtendedClient) extendedSqsWithDefaultConfig).sendStreamMessage(request, stream, smallMessage.length()); + + verify(mockS3, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + assertEquals(smallMessage, sqsCaptor.getValue().messageBody()); + assertFalse(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + } + + @Test + public void testSendStreamMessage_WithS3KeyPrefix() { + int dataSize = MORE_THAN_SQS_SIZE_LIMIT; + String largeData = generateStringWithLength(dataSize); + java.io.InputStream stream = new java.io.ByteArrayInputStream(largeData.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + ((AmazonSQSExtendedClient) 
+        ((AmazonSQSExtendedClient) extendedSqsWithS3KeyPrefix).sendStreamMessage(request, stream, dataSize);
+
+        ArgumentCaptor<PutObjectRequest> s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+        verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class));
+        assertTrue(s3Captor.getValue().key().startsWith(S3_KEY_PREFIX),
+            "S3 key should start with prefix: " + S3_KEY_PREFIX);
+    }
+
+    @Test
+    public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() {
+        ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration()
+            .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+            .withStreamUploadEnabled(true)
+            .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); // 5MB
+
+        SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig));
+
+        List<SendMessageBatchRequestEntry> entries = new ArrayList<>();
+        List<InputStream> streams = new ArrayList<>();
+        List<Long> contentLengths = new ArrayList<>();
+
+        // Small message (100KB - below the SQS size limit)
+        String smallMsg = generateStringWithLength(100_000);
+        entries.add(SendMessageBatchRequestEntry.builder()
+            .id("msg1")
+            .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("small").dataType("String").build()))
+            .build());
+        streams.add(new ByteArrayInputStream(smallMsg.getBytes(StandardCharsets.UTF_8)));
+        contentLengths.add((long) smallMsg.length());
+
+        // Large message (300KB - above the SQS size limit but below the stream threshold)
+        String largeMsg = generateStringWithLength(300_000);
+        entries.add(SendMessageBatchRequestEntry.builder()
+            .id("msg2")
+            .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("large").dataType("String").build()))
+            .build());
+        streams.add(new ByteArrayInputStream(largeMsg.getBytes(StandardCharsets.UTF_8)));
+        contentLengths.add((long) largeMsg.length());
+
+        // Very large message (just above the 5MB stream threshold - uses multipart upload)
+        String veryLargeMsg = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD);
+        entries.add(SendMessageBatchRequestEntry.builder()
+            .id("msg3")
+            .build());
+        streams.add(new ByteArrayInputStream(veryLargeMsg.getBytes(StandardCharsets.UTF_8)));
+        contentLengths.add((long) veryLargeMsg.length());
+
+        SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder()
+            .queueUrl(SQS_QUEUE_URL)
+            .entries(entries)
+            .build();
+
+        ((AmazonSQSExtendedClient) streamClient).sendStreamMessageBatch(batchRequest, streams, contentLengths);
+
+        ArgumentCaptor<SendMessageBatchRequest> sqsCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+        verify(mockSqsBackend, times(1)).sendMessageBatch(sqsCaptor.capture());
+        SendMessageBatchRequestEntry firstEntry = sqsCaptor.getValue().entries().get(0);
+        assertEquals(smallMsg, firstEntry.messageBody());
+        SendMessageBatchRequestEntry secondEntry = sqsCaptor.getValue().entries().get(1);
+        assertTrue(secondEntry.messageBody().contains(S3_BUCKET_NAME));
+        assertTrue(secondEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+        assertTrue(secondEntry.messageAttributes().containsKey("size"));
+        assertEquals("large", secondEntry.messageAttributes().get("size").stringValue());
+        SendMessageBatchRequestEntry thirdEntry = sqsCaptor.getValue().entries().get(2);
+        assertTrue(thirdEntry.messageBody().contains(S3_BUCKET_NAME));
+        assertTrue(thirdEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+    }
+
+    @Test
+    public void testSendStreamMessage_WithEncryption_AppliesKMSToS3Upload() {
+        int dataSize = MORE_THAN_SQS_SIZE_LIMIT;
+        String sensitiveData = generateStringWithLength(dataSize);
+        InputStream stream = new ByteArrayInputStream(sensitiveData.getBytes(StandardCharsets.UTF_8));
+
+        SendMessageRequest request = SendMessageRequest.builder()
+            .queueUrl(SQS_QUEUE_URL)
+            .messageAttributes(ImmutableMap.of(
+                "dataType", MessageAttributeValue.builder().stringValue("sensitive").dataType("String").build()
+            ))
+            .build();
+
+        ((AmazonSQSExtendedClient) extendedSqsWithCustomKMS).sendStreamMessage(request, stream, dataSize);
+
+        ArgumentCaptor<PutObjectRequest> s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+        verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class));
+        assertEquals(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID, s3Captor.getValue().ssekmsKeyId());
+        verify(mockSqsBackend, times(1)).sendMessage(any(SendMessageRequest.class));
+    }
+
 }
diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
index 879f098..b772586 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
@@ -85,4 +85,12 @@ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() {
         assertNotNull(extendedClientConfiguration.getS3AsyncClient());
         assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName());
     }
+
+    @Test
+    public void testStreamUploadEnabled() {
+        ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration();
+        extendedClientConfiguration.withStreamUploadEnabled(true);
+
+        assertTrue(extendedClientConfiguration.isStreamUploadEnabled());
+    }
 }
diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
index 2dc5b6b..e7b3757 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
@@ -224,4 +224,58 @@ public void testS3keyPrefixWithALargeString() {
 
         assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix));
     }
+
+    @Test
+    public void testStreamUploadEnabled() {
+        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+        extendedClientConfiguration.withStreamUploadEnabled(true);
+
+        assertTrue(extendedClientConfiguration.isStreamUploadEnabled());
+    }
+
+    @Test
+    public void testStreamUploadThresholdCustomValue() {
+        int customThreshold = 10 * 1024 * 1024; // 10MB
+        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+        extendedClientConfiguration.withStreamUploadThreshold(customThreshold);
+
+        assertEquals(customThreshold, extendedClientConfiguration.getStreamUploadThreshold());
+    }
+
+    @Test
+    public void testStreamUploadPartSizeCustomValue() {
+        int customPartSize = 10 * 1024 * 1024; // 10MB
+        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+        extendedClientConfiguration.withStreamUploadPartSize(customPartSize);
+
+        assertEquals(customPartSize, extendedClientConfiguration.getStreamUploadPartSize());
+    }
+
+    @Test
+    public void testStreamUploadPartSizeBelowMinimumRoundedUpTo5MB() {
+        int belowMinimum = 3 * 1024 * 1024; // 3MB (below the 5MB minimum)
+        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+        extendedClientConfiguration.withStreamUploadPartSize(belowMinimum);
+
+        assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadPartSize());
+    }
+
+    @Test
+    public void testStreamConfigurationInCopyConstructor() {
+        S3Client s3 = mock(S3Client.class);
+        int customThreshold = 10 * 1024 * 1024;
+        int customPartSize = 8 * 1024 * 1024;
+
+        ExtendedClientConfiguration originalConfig = new ExtendedClientConfiguration();
+        originalConfig.withPayloadSupportEnabled(s3, s3BucketName)
+            .withStreamUploadEnabled(true)
+            .withStreamUploadThreshold(customThreshold)
+            .withStreamUploadPartSize(customPartSize);
+
+        ExtendedClientConfiguration copiedConfig = new ExtendedClientConfiguration(originalConfig);
+
+        assertTrue(copiedConfig.isStreamUploadEnabled());
+        assertEquals(customThreshold, copiedConfig.getStreamUploadThreshold());
+        assertEquals(customPartSize, copiedConfig.getStreamUploadPartSize());
+    }
 }
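
Usage sketch (reviewer note, not part of the diff): the following illustrates how a caller might wire up the stream-upload options and the sendStreamMessage/receiveMessageAsStream entry points exercised by the tests above. The configuration methods and accessor names are taken from this change; the bucket name, queue URL, file path, and the exact parameter and return types (e.g. a long content length, the StreamMessage accessors) are illustrative assumptions, not a confirmed API.

    import java.io.InputStream;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import software.amazon.awssdk.services.s3.S3Client;
    import software.amazon.awssdk.services.sqs.SqsClient;
    import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
    import software.amazon.awssdk.services.sqs.model.SendMessageRequest;

    public class StreamUploadExample {
        public static void main(String[] args) throws Exception {
            // Hypothetical resources; replace with a real bucket and queue.
            String bucket = "my-extended-client-bucket";
            String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/my-queue";

            ExtendedClientConfiguration config = new ExtendedClientConfiguration()
                .withPayloadSupportEnabled(S3Client.create(), bucket)
                .withStreamUploadEnabled(true)              // new flag introduced by this change
                .withStreamUploadThreshold(5 * 1024 * 1024) // switch to multipart upload above 5MB
                .withStreamUploadPartSize(5 * 1024 * 1024); // 5MB parts (the enforced minimum)

            AmazonSQSExtendedClient sqs = new AmazonSQSExtendedClient(SqsClient.create(), config);

            // Send side: stream a large file to S3 and send only the pointer to SQS.
            Path file = Path.of("large-payload.json"); // hypothetical input file
            try (InputStream in = Files.newInputStream(file)) {
                SendMessageRequest send = SendMessageRequest.builder().queueUrl(queueUrl).build();
                sqs.sendStreamMessage(send, in, Files.size(file));
            }

            // Receive side: get streaming access to the offloaded payload instead of an in-memory body.
            ReceiveMessageRequest receive = ReceiveMessageRequest.builder().queueUrl(queueUrl).build();
            ReceiveStreamMessageResponse response = sqs.receiveMessageAsStream(receive);
            for (StreamMessage message : response.streamMessages()) {
                if (message.hasStreamPayload()) {
                    try (InputStream payload = message.getPayloadStream()) {
                        // Consume the payload incrementally without buffering it in memory.
                    }
                }
            }
            sqs.close();
        }
    }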