batchEntries = new ArrayList<>(sendMessageBatchRequest.entries().size());
+
+ boolean hasS3Entries = false;
+ for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) {
+ SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i);
+ InputStream stream = messageBodyStreams.get(i);
+ long contentLength = contentLengths.get(i);
+
+ //Check message attributes for ExtendedClient related constraints
+ checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes());
+
+ if (clientConfiguration.isAlwaysThroughS3()
+ || contentLength >= clientConfiguration.getPayloadSizeThreshold()) {
+ entry = storeStreamMessageInS3(entry, stream, contentLength);
+ hasS3Entries = true;
+ } else {
+ // Convert stream to string for small messages
+ try {
+ String messageBody = IoUtils.toUtf8String(stream);
+ entry = entry.toBuilder().messageBody(messageBody).build();
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to read from InputStream", e);
+ }
+ }
+ batchEntries.add(entry);
+ }
+
+ if (hasS3Entries) {
+ sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(batchEntries).build();
+ }
+
+ return super.sendMessageBatch(sendMessageBatchRequest);
+ }
+
/**
*
* Deletes up to ten messages from the specified queue. This is a batch version of
@@ -900,6 +1145,68 @@ private String storeOriginalPayload(String messageContentStr) {
}
return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID());
}
+
+ /**
+ * Offloads the payload of a SendMessageRequest to the payload store and rewrites
+ * the request so its body carries the payload pointer instead of the content.
+ *
+ * @param sendMessageRequest original request; its body is replaced by the pointer
+ * @param messageBodyStream payload content to store
+ * @param contentLength payload size in bytes, recorded in the reserved size attribute
+ * @return rebuilt request whose body is the stored-payload pointer
+ */
+ private SendMessageRequest storeStreamMessageInS3(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) {
+ SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder();
+
+ // Record the original payload size in the reserved attribute so receivers
+ // can detect the offloaded payload (legacy attribute name honored if configured).
+ sendMessageRequestBuilder.messageAttributes(
+ updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), contentLength,
+ clientConfiguration.usesLegacyReservedAttributeName()));
+
+ // Store the message content in S3.
+ String largeMessagePointer = storeOriginalPayload(messageBodyStream);
+ sendMessageRequestBuilder.messageBody(largeMessagePointer);
+
+ return sendMessageRequestBuilder.build();
+ }
+
+ /**
+ * Batch-entry variant of the stream offload: stores the entry's payload and
+ * rewrites the entry body to the payload pointer.
+ *
+ * @param batchEntry original batch entry; its body is replaced by the pointer
+ * @param messageBodyStream payload content to store
+ * @param contentLength payload size in bytes, recorded in the reserved size attribute
+ * @return rebuilt batch entry whose body is the stored-payload pointer
+ */
+ private SendMessageBatchRequestEntry storeStreamMessageInS3(SendMessageBatchRequestEntry batchEntry, InputStream messageBodyStream, long contentLength) {
+ SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder();
+
+ // Same reserved size attribute as the single-message path, so receivers
+ // treat offloaded batch entries identically.
+ batchEntryBuilder.messageAttributes(
+ updateMessageAttributePayloadSize(batchEntry.messageAttributes(), contentLength,
+ clientConfiguration.usesLegacyReservedAttributeName()));
+
+ // Store the message content in S3.
+ String largeMessagePointer = storeOriginalPayload(messageBodyStream);
+ batchEntryBuilder.messageBody(largeMessagePointer);
+
+ return batchEntryBuilder.build();
+ }
+
+    /**
+     * Stores a payload read from {@code messageContentStream} in the payload store
+     * and returns the pointer string to embed in the SQS message body.
+     *
+     * Prefers the streaming upload path when the configured store supports it; if a
+     * stream upload fails, the stream is reset (when possible) and the payload is
+     * read fully into memory and stored via the standard single-part path.
+     *
+     * @param messageContentStream payload source; must support reset() for the
+     *                             fallback after a failed stream upload, otherwise
+     *                             the original upload failure is rethrown
+     * @return pointer to the stored payload
+     */
+    private String storeOriginalPayload(InputStream messageContentStream) {
+        String s3KeyPrefix = clientConfiguration.getS3KeyPrefix();
+        String key = StringUtils.isBlank(s3KeyPrefix) ? UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID();
+
+        if (payloadStore instanceof StreamPayloadStore) {
+            try {
+                return ((StreamPayloadStore) payloadStore).storeOriginalPayloadStream(messageContentStream, key);
+            } catch (RuntimeException e) {
+                // Include the cause so the failed stream upload is diagnosable from logs.
+                LOG.warn("Stream upload attempt failed; falling back to standard single-part upload.", e);
+                try {
+                    messageContentStream.reset();
+                } catch (Exception resetEx) {
+                    // Cannot replay the stream, so the fallback would store a truncated
+                    // payload; surface the original upload failure instead.
+                    LOG.warn("Failed to reset stream after multipart upload failure, stream may be exhausted", resetEx);
+                    throw e;
+                }
+                // Reset succeeded: fall through to the shared single-part path below.
+            }
+        }
+
+        return readStreamAndStorePayload(messageContentStream, key);
+    }
+
+    /**
+     * Reads the remaining stream content as UTF-8 and stores it through the
+     * standard (in-memory) payload store path. Shared by the non-streaming case
+     * and the post-reset fallback, replacing two duplicated copies of this logic.
+     */
+    private String readStreamAndStorePayload(InputStream messageContentStream, String key) {
+        try {
+            String content = IoUtils.toUtf8String(messageContentStream);
+            return payloadStore.storeOriginalPayload(content, key);
+        } catch (IOException e) {
+            throw new RuntimeException("Failed to read from InputStream", e);
+        }
+    }
@SuppressWarnings("unchecked")
private static T appendUserAgent(final T builder) {
diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java
index e96fff2..281bf3b 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java
@@ -212,4 +212,10 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry
this.setServerSideEncryptionStrategy(serverSideEncryption);
return this;
}
+
+ /**
+ * Enables or disables stream upload support for large payload storage operations,
+ * returning the async configuration type for fluent chaining.
+ *
+ * @param enabled true to enable stream uploads
+ * @return this configuration
+ */
+ @Override
+ public ExtendedAsyncClientConfiguration withStreamUploadEnabled(boolean enabled) {
+ setStreamUploadEnabled(enabled);
+ return this;
+ }
}
\ No newline at end of file
diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
index 75a30f8..9e0d41e 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
@@ -232,6 +232,26 @@ public ExtendedClientConfiguration withServerSideEncryption(ServerSideEncryption
return this;
}
+ /**
+ * Enables or disables stream upload support for large payload storage operations.
+ * When enabled, payloads may be uploaded via the streaming payload store path
+ * instead of being buffered fully in memory.
+ *
+ * @param enabled true to enable stream uploads when threshold exceeded.
+ * @return updated configuration
+ */
+ public ExtendedClientConfiguration withStreamUploadEnabled(boolean enabled) {
+ setStreamUploadEnabled(enabled);
+ return this;
+ }
+
+ /**
+ * Sets the payload size above which stream uploads are used when stream upload
+ * support is enabled.
+ *
+ * @param threshold size threshold (presumably bytes — TODO confirm against setStreamUploadThreshold)
+ * @return updated configuration
+ */
+ public ExtendedClientConfiguration withStreamUploadThreshold(int threshold) {
+ setStreamUploadThreshold(threshold);
+ return this;
+ }
+
+ /**
+ * Sets the part size used for stream (multipart-style) uploads.
+ *
+ * @param partSize part size (presumably bytes — TODO confirm against setStreamUploadPartSize)
+ * @return updated configuration
+ */
+ public ExtendedClientConfiguration withStreamUploadPartSize(int partSize) {
+ setStreamUploadPartSize(partSize);
+ return this;
+ }
+
/**
* Enables support for large-payload messages.
*
diff --git a/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java
new file mode 100644
index 0000000..c93d458
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java
@@ -0,0 +1,21 @@
+package com.amazon.sqs.javamessaging;
+
+import java.util.List;
+
+/**
+ * Response containing messages with streaming access to large payloads.
+ */
+public class ReceiveStreamMessageResponse {
+    // Element type made explicit: callers (e.g. response.streamMessages().get(0))
+    // consume StreamMessage instances; the raw List in the diff is stripped generics.
+    private final List<StreamMessage> streamMessages;
+
+    public ReceiveStreamMessageResponse(List<StreamMessage> streamMessages) {
+        this.streamMessages = streamMessages;
+    }
+
+    /**
+     * @return list of messages with streaming payload access
+     */
+    public List<StreamMessage> streamMessages() {
+        return streamMessages;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java
new file mode 100644
index 0000000..9f08986
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java
@@ -0,0 +1,40 @@
+package com.amazon.sqs.javamessaging;
+
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.sqs.model.Message;
+
+/**
+ * Represents a message received from SQS that can provide streaming access to large payloads stored in S3.
+ * Instead of loading the entire payload into memory, this provides access to the payload as a stream.
+ */
+public class StreamMessage {
+    private final Message originalMessage;
+    // Type argument restored (the file imports GetObjectResponse for exactly this):
+    // the stream wraps the S3 GetObject response.
+    private final ResponseInputStream<GetObjectResponse> payloadStream;
+
+    public StreamMessage(Message originalMessage, ResponseInputStream<GetObjectResponse> payloadStream) {
+        this.originalMessage = originalMessage;
+        this.payloadStream = payloadStream;
+    }
+
+    /**
+     * @return the original SQS message metadata (messageId, receiptHandle, attributes, etc.)
+     */
+    public Message getMessage() {
+        return originalMessage;
+    }
+
+    /**
+     * @return stream for accessing the message payload, or null if payload is not stored in S3
+     */
+    public ResponseInputStream<GetObjectResponse> getPayloadStream() {
+        return payloadStream;
+    }
+
+    /**
+     * @return true if this message has a streaming payload from S3
+     */
+    public boolean hasStreamPayload() {
+        return payloadStream != null;
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java
index 219e57d..e3a42e9 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java
@@ -6,9 +6,12 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doThrow;
@@ -20,6 +23,8 @@
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import java.io.IOException;
+import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
@@ -38,6 +43,7 @@
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
+import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
@@ -63,10 +69,14 @@
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
-import software.amazon.awssdk.utils.ImmutableMap;
+import software.amazon.awssdk.core.ResponseInputStream;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.payloadoffloading.StreamPayloadStoreAsync;
+import software.amazon.payloadoffloading.PayloadStoreAsync;
import software.amazon.payloadoffloading.PayloadS3Pointer;
import software.amazon.payloadoffloading.ServerSideEncryptionFactory;
import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
+import software.amazon.awssdk.utils.ImmutableMap;
public class AmazonSQSExtendedAsyncClientTest {
@@ -88,6 +98,11 @@ public class AmazonSQSExtendedAsyncClientTest {
// should be > 1 and << SQS_SIZE_LIMIT
private static final int ARBITRARY_SMALLER_THRESHOLD = 500;
+
+ // Stream upload thresholds
+ private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default
+ private static final int LESS_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD - 1;
+ private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1;
@BeforeEach
public void setupClients() {
@@ -729,4 +744,338 @@ private String generateStringWithLength(int messageLength) {
Arrays.fill(charArray, 'x');
return new String(charArray);
}
+
+ @Test
+ public void testReceiveMessageAsStream_PayloadSupportDisabled_ReturnsMessagesWithoutStreams() {
+ // With payload support disabled the client must pass messages through
+ // untouched: body preserved, no S3 lookup, no payload stream attached.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportDisabled();
+ AmazonSQSExtendedAsyncClient clientWithDisabledPayload = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ Message message = Message.builder()
+ .messageId("msg1")
+ .body("small message")
+ .receiptHandle("receipt1")
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(message)
+ .build();
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ CompletableFuture future = clientWithDisabledPayload.receiveMessageAsStream(request);
+ ReceiveStreamMessageResponse response = future.join();
+
+ assertEquals(1, response.streamMessages().size());
+ StreamMessage streamMessage = response.streamMessages().get(0);
+ assertEquals("msg1", streamMessage.getMessage().messageId());
+ assertEquals("small message", streamMessage.getMessage().body());
+ assertFalse(streamMessage.hasStreamPayload());
+ }
+
+ @Test
+ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException {
+ // A message whose body is an S3 pointer should come back with the mocked
+ // S3 object stream attached and an embedded receipt handle referencing the key.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withPayloadSizeThreshold(262144); // 256KB
+ config.setStreamUploadEnabled(true);
+
+ AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson();
+ Message message = Message.builder()
+ .messageId("msg1")
+ .body(s3Pointer)
+ .receiptHandle("receipt1")
+ .messageAttributes(ImmutableMap.of(
+ "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build()
+ ))
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(message)
+ .build();
+
+ @SuppressWarnings("unchecked")
+ ResponseInputStream mockStream = mock(ResponseInputStream.class);
+ when(mockStream.read(any(byte[].class))).thenReturn(-1);
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ @SuppressWarnings("unchecked")
+ CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream);
+ when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(futureStream);
+
+ CompletableFuture future = clientWithStream.receiveMessageAsStream(request);
+ ReceiveStreamMessageResponse response = future.join();
+
+ assertEquals(1, response.streamMessages().size());
+ StreamMessage streamMessage = response.streamMessages().get(0);
+ assertEquals("msg1", streamMessage.getMessage().messageId());
+ assertTrue(streamMessage.hasStreamPayload());
+ assertSame(mockStream, streamMessage.getPayloadStream());
+ assertTrue(streamMessage.getMessage().receiptHandle().contains("test-key"));
+ }
+
+ @Test
+ public void testReceiveMessageAsStream_LargeMessage_WithoutStreamStore_FallsBackToRegularRetrieval() {
+ // Without a stream-capable store the payload is fetched in full from S3
+ // and returned in the message body, so no payload stream is attached.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME);
+
+ AmazonSQSExtendedAsyncClient clientWithRegularStore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson();
+ Message message = Message.builder()
+ .messageId("msg1")
+ .body(s3Pointer)
+ .receiptHandle("receipt1")
+ .messageAttributes(ImmutableMap.of(
+ "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build()
+ ))
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(message)
+ .build();
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ ResponseBytes s3Object = ResponseBytes.fromByteArray(
+ GetObjectResponse.builder().build(),
+ "large payload content".getBytes(StandardCharsets.UTF_8));
+ when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(CompletableFuture.completedFuture(s3Object));
+
+ CompletableFuture future = clientWithRegularStore.receiveMessageAsStream(request);
+ ReceiveStreamMessageResponse response = future.join();
+
+ assertEquals(1, response.streamMessages().size());
+ StreamMessage streamMessage = response.streamMessages().get(0);
+ assertEquals("msg1", streamMessage.getMessage().messageId());
+ assertEquals("large payload content", streamMessage.getMessage().body());
+ assertFalse(streamMessage.hasStreamPayload());
+ }
+
+ @Test
+ public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundEnabled_DeletesMessage() {
+ // With ignorePayloadNotFound enabled, a NoSuchKey failure from S3 drops the
+ // message from the response and deletes it from the queue by receipt handle.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withIgnorePayloadNotFound(true);
+
+ AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson();
+ Message message = Message.builder()
+ .messageId("msg1")
+ .body(s3Pointer)
+ .receiptHandle("receipt1")
+ .messageAttributes(ImmutableMap.of(
+ "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build()
+ ))
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(message)
+ .build();
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ CompletableFuture> failedFuture = new CompletableFuture<>();
+ failedFuture.completeExceptionally(NoSuchKeyException.builder().build());
+ when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(failedFuture);
+
+ when(mockSqsBackend.deleteMessage(any(DeleteMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(DeleteMessageResponse.builder().build()));
+
+ CompletableFuture future = clientWithIgnore.receiveMessageAsStream(request);
+ ReceiveStreamMessageResponse response = future.join();
+
+ assertTrue(response.streamMessages().isEmpty());
+
+ ArgumentCaptor deleteCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class);
+ verify(mockSqsBackend).deleteMessage(deleteCaptor.capture());
+ assertEquals("receipt1", deleteCaptor.getValue().receiptHandle());
+ }
+
+ @Test
+ public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundDisabled_ThrowsException() {
+ // With ignorePayloadNotFound disabled the S3 failure must propagate out of
+ // join() as a CompletionException instead of being swallowed.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withIgnorePayloadNotFound(false);
+
+ AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson();
+ Message message = Message.builder()
+ .messageId("msg1")
+ .body(s3Pointer)
+ .receiptHandle("receipt1")
+ .messageAttributes(ImmutableMap.of(
+ "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build()
+ ))
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(message)
+ .build();
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ CompletableFuture> failedFuture = new CompletableFuture<>();
+ failedFuture.completeExceptionally(NoSuchKeyException.builder().build());
+ when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(failedFuture);
+
+ assertThrows(CompletionException.class, () -> {
+ clientWithIgnore.receiveMessageAsStream(request).join();
+ });
+ }
+
+ @Test
+ public void testReceiveMessageAsStream_MultipleMessages_MixedTypes() throws IOException {
+ // A batch with one inline message and one S3-pointer message must preserve
+ // order and attach a stream only to the S3-backed message.
+ ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withPayloadSizeThreshold(262144); // 256KB
+ config.setStreamUploadEnabled(true); // Enable stream uploads
+
+ AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ // Small message (no S3)
+ Message smallMessage = Message.builder()
+ .messageId("msg1")
+ .body("small message")
+ .receiptHandle("receipt1")
+ .build();
+
+ // Large message (with S3 pointer)
+ String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson();
+ Message largeMessage = Message.builder()
+ .messageId("msg2")
+ .body(s3Pointer)
+ .receiptHandle("receipt2")
+ .messageAttributes(ImmutableMap.of(
+ "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build()
+ ))
+ .build();
+
+ ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder()
+ .messages(Arrays.asList(smallMessage, largeMessage))
+ .build();
+
+ @SuppressWarnings("unchecked")
+ ResponseInputStream mockStream = mock(ResponseInputStream.class);
+ when(mockStream.read(any(byte[].class))).thenReturn(-1);
+
+ when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class)))
+ .thenReturn(CompletableFuture.completedFuture(sqsResponse));
+
+ @SuppressWarnings("unchecked")
+ CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream);
+ when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class)))
+ .thenReturn(futureStream);
+
+ CompletableFuture future = clientWithStream.receiveMessageAsStream(request);
+ ReceiveStreamMessageResponse response = future.join();
+
+ assertEquals(2, response.streamMessages().size());
+
+ StreamMessage msg1 = response.streamMessages().get(0);
+ assertEquals("msg1", msg1.getMessage().messageId());
+ assertFalse(msg1.hasStreamPayload());
+
+ StreamMessage msg2 = response.streamMessages().get(1);
+ assertEquals("msg2", msg2.getMessage().messageId());
+ assertTrue(msg2.hasStreamPayload());
+ assertSame(mockStream, msg2.getPayloadStream());
+ }
+
+ @Test
+ public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() {
+ // A payload above the threshold must be put to S3 once, with the SQS message
+ // carrying the bucket pointer, the caller's attributes, and the reserved
+ // size attribute set to the declared content length.
+ int fileSizeBytes = 500_000;
+ String fileContent = generateStringWithLength(fileSizeBytes);
+ java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8));
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageAttributes(ImmutableMap.of(
+ "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(),
+ "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build()
+ ))
+ .build();
+
+ ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig)
+ .sendStreamMessage(request, fileStream, fileSizeBytes).join();
+
+ ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ verify(mockS3, times(1)).putObject(s3Captor.capture(), any(AsyncRequestBody.class));
+ assertEquals(S3_BUCKET_NAME, s3Captor.getValue().bucket());
+
+ ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture());
+
+ assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME));
+ assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName"));
+ assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType"));
+ assertTrue(sqsCaptor.getValue().messageAttributes()
+ .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertEquals(String.valueOf(fileSizeBytes),
+ sqsCaptor.getValue().messageAttributes()
+ .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue());
+ }
+
+ @Test
+ public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() {
+ // A payload below the threshold is sent inline: no S3 putObject, body equals
+ // the original content, and no reserved size attribute is added.
+ String smallMessage = "Small notification message";
+ java.io.InputStream stream = new java.io.ByteArrayInputStream(smallMessage.getBytes(StandardCharsets.UTF_8));
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageAttributes(ImmutableMap.of(
+ "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build()
+ ))
+ .build();
+
+ ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig)
+ .sendStreamMessage(request, stream, smallMessage.length()).join();
+
+ verify(mockS3, never()).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class));
+
+ ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture());
+ assertEquals(smallMessage, sqsCaptor.getValue().messageBody());
+
+ assertFalse(sqsCaptor.getValue().messageAttributes()
+ .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ }
}
+
diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
index 96ac44f..858abb6 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
@@ -68,6 +68,7 @@
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@@ -82,6 +83,12 @@
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.io.ByteArrayInputStream;
+
+
/**
* Tests the AmazonSQSExtendedClient class.
*/
@@ -113,6 +120,13 @@ public class AmazonSQSExtendedClientTest {
// should be > 1 and << SQS_SIZE_LIMIT
private static final int ARBITRARY_SMALLER_THRESHOLD = 500;
+
+ // Stream upload thresholds
+ private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default
+ private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1;
+
+ // Stream part size
+ private static final int STREAM_UPLOAD_PART_SIZE = 5 * 1024 * 1024; // 5MB
@BeforeEach
public void setupClients() {
@@ -777,4 +791,204 @@ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle)
private String getSampleLargeReceiptHandle(String originalReceiptHandle) {
return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle);
}
+
+ @Test
+ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException {
+ String largeMessageBody = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD);
+
+ // Mock an S3 object stream that immediately reports end-of-stream.
+ ResponseInputStream mockStream = mock(ResponseInputStream.class);
+ when(mockStream.read(any(byte[].class))).thenReturn(-1);
+ when(mockS3.getObject(isA(GetObjectRequest.class))).thenReturn(mockStream);
+
+ String s3Key = "stream-key-" + UUID.randomUUID();
+ String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, s3Key).toJson();
+ Message sqsMessage = Message.builder()
+ .messageAttributes(ImmutableMap.of(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME,
+ MessageAttributeValue.builder().dataType("Number").stringValue(String.valueOf(largeMessageBody.length())).build()))
+ .body(pointer)
+ .receiptHandle("test-receipt-handle")
+ .build();
+
+ when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class)))
+ .thenReturn(ReceiveMessageResponse.builder().messages(sqsMessage).build());
+
+ ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withStreamUploadEnabled(true)
+ .withStreamUploadPartSize(STREAM_UPLOAD_PART_SIZE)
+ .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD);
+ AmazonSQSExtendedClient sqsExtended = new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration);
+
+ ReceiveMessageRequest request = ReceiveMessageRequest.builder().queueUrl(SQS_QUEUE_URL).build();
+ ReceiveStreamMessageResponse response = sqsExtended.receiveMessageAsStream(request);
+ sqsExtended.close();
+
+ assertEquals(1, response.streamMessages().size());
+ StreamMessage streamMessage = response.streamMessages().get(0);
+
+ assertTrue(streamMessage.hasStreamPayload());
+ assertNotNull(streamMessage.getPayloadStream());
+
+ verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class));
+ // The rewritten receipt handle embeds the bucket name and the randomly
+ // generated S3 key, so the expected value must be derived from s3Key
+ // (via the shared helper) rather than hard-coded.
+ assertEquals(getLargeReceiptHandle(s3Key, "test-receipt-handle"), streamMessage.getMessage().receiptHandle());
+ }
+
+ @Test
+ public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() {
+ int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD;
+ String fileContent = generateStringWithLength(fileSizeBytes);
+ InputStream fileStream = new ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8));
+
+ ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withStreamUploadEnabled(true)
+ .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD);
+
+ // No spy needed: all verification happens on mockSqsBackend, and a
+ // concrete declaration avoids the cast when calling sendStreamMessage.
+ AmazonSQSExtendedClient streamClient = new AmazonSQSExtendedClient(mockSqsBackend, streamConfig);
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageAttributes(ImmutableMap.of(
+ "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(),
+ "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build()
+ ))
+ .build();
+
+ streamClient.sendStreamMessage(request, fileStream, fileSizeBytes);
+
+ // Typed captor: a raw ArgumentCaptor makes getValue() return Object,
+ // which does not compile against messageBody().
+ ArgumentCaptor<SendMessageRequest> sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture());
+
+ // The body sent to SQS must be an S3 pointer, with the caller's
+ // attributes preserved and the reserved size attribute appended.
+ assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME));
+ assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName"));
+ assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType"));
+ assertTrue(sqsCaptor.getValue().messageAttributes()
+ .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertEquals(String.valueOf(fileSizeBytes),
+ sqsCaptor.getValue().messageAttributes()
+ .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue());
+ }
+
+ @Test
+ public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() {
+ String smallMessage = "Small notification message";
+ // Use the imported types directly; this patch already adds the
+ // java.io/java.nio imports, so fully-qualified names are redundant.
+ InputStream stream = new ByteArrayInputStream(smallMessage.getBytes(StandardCharsets.UTF_8));
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageAttributes(ImmutableMap.of(
+ "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build()
+ ))
+ .build();
+
+ ((AmazonSQSExtendedClient) extendedSqsWithDefaultConfig).sendStreamMessage(request, stream, smallMessage.length());
+
+ // Small payloads must bypass S3 entirely and keep the original body.
+ verify(mockS3, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class));
+ ArgumentCaptor<SendMessageRequest> sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture());
+ assertEquals(smallMessage, sqsCaptor.getValue().messageBody());
+ assertFalse(sqsCaptor.getValue().messageAttributes()
+ .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ }
+
+ @Test
+ public void testSendStreamMessage_WithS3KeyPrefix() {
+ int dataSize = MORE_THAN_SQS_SIZE_LIMIT;
+ String largeData = generateStringWithLength(dataSize);
+ InputStream stream = new ByteArrayInputStream(largeData.getBytes(StandardCharsets.UTF_8));
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .build();
+
+ ((AmazonSQSExtendedClient) extendedSqsWithS3KeyPrefix).sendStreamMessage(request, stream, dataSize);
+
+ // Typed captor: raw ArgumentCaptor would not compile against key().
+ ArgumentCaptor<PutObjectRequest> s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class));
+ assertTrue(s3Captor.getValue().key().startsWith(S3_KEY_PREFIX),
+ "S3 key should start with prefix: " + S3_KEY_PREFIX);
+ }
+
+ @Test
+ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() {
+ ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withStreamUploadEnabled(true)
+ .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); // 5MB
+
+ AmazonSQSExtendedClient streamClient = new AmazonSQSExtendedClient(mockSqsBackend, streamConfig);
+
+ // Parameterized lists; raw types would only hide element-type mistakes
+ // behind unchecked warnings.
+ List<SendMessageBatchRequestEntry> entries = new ArrayList<>();
+ List<InputStream> streams = new ArrayList<>();
+ List<Long> contentLengths = new ArrayList<>();
+
+ // Small message (100KB - below SQS limit): sent inline.
+ String smallMsg = generateStringWithLength(100_000);
+ entries.add(SendMessageBatchRequestEntry.builder()
+ .id("msg1")
+ .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("small").dataType("String").build()))
+ .build());
+ streams.add(new ByteArrayInputStream(smallMsg.getBytes(StandardCharsets.UTF_8)));
+ contentLengths.add((long) smallMsg.length());
+
+ // Large message (300KB - above SQS limit but below stream threshold): S3 pointer.
+ String largeMsg = generateStringWithLength(300_000);
+ entries.add(SendMessageBatchRequestEntry.builder()
+ .id("msg2")
+ .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("large").dataType("String").build()))
+ .build());
+ streams.add(new ByteArrayInputStream(largeMsg.getBytes(StandardCharsets.UTF_8)));
+ contentLengths.add((long) largeMsg.length());
+
+ // Very large message (6MB - above stream threshold, uses multipart): S3 pointer.
+ String veryLargeMsg = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD);
+ entries.add(SendMessageBatchRequestEntry.builder()
+ .id("msg3")
+ .build());
+ streams.add(new ByteArrayInputStream(veryLargeMsg.getBytes(StandardCharsets.UTF_8)));
+ contentLengths.add((long) veryLargeMsg.length());
+
+ SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .entries(entries)
+ .build();
+
+ streamClient.sendStreamMessageBatch(batchRequest, streams, contentLengths);
+
+ // Typed captor: raw ArgumentCaptor would not compile against entries().
+ ArgumentCaptor<SendMessageBatchRequest> sqsCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(mockSqsBackend, times(1)).sendMessageBatch(sqsCaptor.capture());
+ SendMessageBatchRequestEntry firstEntry = sqsCaptor.getValue().entries().get(0);
+ assertEquals(smallMsg, firstEntry.messageBody());
+ SendMessageBatchRequestEntry secondEntry = sqsCaptor.getValue().entries().get(1);
+ assertTrue(secondEntry.messageBody().contains(S3_BUCKET_NAME));
+ assertTrue(secondEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertTrue(secondEntry.messageAttributes().containsKey("size"));
+ assertEquals("large", secondEntry.messageAttributes().get("size").stringValue());
+ SendMessageBatchRequestEntry thirdEntry = sqsCaptor.getValue().entries().get(2);
+ assertTrue(thirdEntry.messageBody().contains(S3_BUCKET_NAME));
+ assertTrue(thirdEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ }
+
+ @Test
+ public void testSendStreamMessage_WithEncryption_AppliesKMSToS3Upload() {
+ int dataSize = MORE_THAN_SQS_SIZE_LIMIT;
+ String sensitiveData = generateStringWithLength(dataSize);
+ InputStream stream = new ByteArrayInputStream(sensitiveData.getBytes(StandardCharsets.UTF_8));
+
+ SendMessageRequest request = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageAttributes(ImmutableMap.of(
+ "dataType", MessageAttributeValue.builder().stringValue("sensitive").dataType("String").build()
+ ))
+ .build();
+
+ ((AmazonSQSExtendedClient) extendedSqsWithCustomKMS).sendStreamMessage(request, stream, dataSize);
+
+ // Typed captor: raw ArgumentCaptor would not compile against ssekmsKeyId().
+ ArgumentCaptor<PutObjectRequest> s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class));
+ assertEquals(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID, s3Captor.getValue().ssekmsKeyId());
+ verify(mockSqsBackend, times(1)).sendMessage(any(SendMessageRequest.class));
+ }
+
}
diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
index 879f098..b772586 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
@@ -85,4 +85,12 @@ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() {
assertNotNull(extendedClientConfiguration.getS3AsyncClient());
assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName());
}
+
+ @Test
+ public void testStreamUploadEnabled() {
+ // Renamed from "testStreamUploadEnabledEnabled" (duplicated word).
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration();
+ extendedClientConfiguration.withStreamUploadEnabled(true);
+
+ assertTrue(extendedClientConfiguration.isStreamUploadEnabled());
+ }
}
diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
index 2dc5b6b..e7b3757 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java
@@ -224,4 +224,58 @@ public void testS3keyPrefixWithALargeString() {
assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix));
}
+
+ @Test
+ public void testStreamUploadEnabled() {
+ ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+ extendedClientConfiguration.withStreamUploadEnabled(true);
+
+ assertTrue(extendedClientConfiguration.isStreamUploadEnabled());
+ }
+
+ @Test
+ public void testStreamUploadThresholdCustomValue() {
+ // A custom threshold should be stored and returned unchanged.
+ int customThreshold = 10 * 1024 * 1024; // 10MB
+ ExtendedClientConfiguration config = new ExtendedClientConfiguration();
+ config.withStreamUploadThreshold(customThreshold);
+ assertEquals(customThreshold, config.getStreamUploadThreshold());
+ }
+
+ @Test
+ public void testStreamUploadPartSizeCustomValue() {
+ // A part size at or above the minimum should be stored unchanged.
+ int customPartSize = 10 * 1024 * 1024; // 10MB
+ ExtendedClientConfiguration config = new ExtendedClientConfiguration();
+ config.withStreamUploadPartSize(customPartSize);
+ assertEquals(customPartSize, config.getStreamUploadPartSize());
+ }
+
+ @Test
+ public void testStreamUploadPartSizeBelowMinimumRoundedUpTo5MB() {
+ // S3 multipart uploads require parts of at least 5MB; smaller requested
+ // part sizes are expected to be raised to that floor rather than rejected.
+ int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum)
+ ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration();
+ extendedClientConfiguration.withStreamUploadPartSize(belowMinimum);
+
+ assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadPartSize());
+ }
+
+ @Test
+ public void testStreamConfigurationInCopyConstructor() {
+ // The copy constructor must carry over every stream-upload setting
+ // (flag, threshold, and part size), not just the payload-support fields.
+ S3Client s3 = mock(S3Client.class);
+ int customThreshold = 10 * 1024 * 1024;
+ int customPartSize = 8 * 1024 * 1024;
+
+ ExtendedClientConfiguration originalConfig = new ExtendedClientConfiguration();
+ originalConfig.withPayloadSupportEnabled(s3, s3BucketName)
+ .withStreamUploadEnabled(true)
+ .withStreamUploadThreshold(customThreshold)
+ .withStreamUploadPartSize(customPartSize);
+
+ ExtendedClientConfiguration copiedConfig = new ExtendedClientConfiguration(originalConfig);
+
+ assertTrue(copiedConfig.isStreamUploadEnabled());
+ assertEquals(customThreshold, copiedConfig.getStreamUploadThreshold());
+ assertEquals(customPartSize, copiedConfig.getStreamUploadPartSize());
+ }
}