From 85e558cf66f46c21894ae6e8feed8568fb4a9091 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Tue, 7 Oct 2025 16:52:08 +0700 Subject: [PATCH 1/8] Add multipart support in uploading to S3 --- .../AmazonSQSExtendedAsyncClient.java | 44 +++- .../AmazonSQSExtendedClient.java | 39 +++- .../ExtendedAsyncClientConfiguration.java | 25 +++ .../ExtendedClientConfiguration.java | 30 +++ .../AmazonSQSExtendedAsyncClientTest.java | 194 +++++++++++++++++ .../AmazonSQSExtendedClientTest.java | 197 ++++++++++++++++++ .../ExtendedAsyncClientConfigurationTest.java | 43 ++++ .../ExtendedClientConfigurationTest.java | 54 +++++ 8 files changed, 617 insertions(+), 9 deletions(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 7ebe3b8..7907cc3 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -50,6 +50,7 @@ import software.amazon.payloadoffloading.PayloadStoreAsync; import software.amazon.payloadoffloading.S3AsyncDao; import software.amazon.payloadoffloading.S3BackedPayloadStoreAsync; +import software.amazon.payloadoffloading.S3BackedMultipartPayloadStoreAsync; import software.amazon.payloadoffloading.Util; /** @@ -129,7 +130,15 @@ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient, S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); - this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + if (clientConfiguration.isMultipartUploadEnabled()) { + this.payloadStore = new S3BackedMultipartPayloadStoreAsync( + s3Dao, + clientConfiguration.getS3BucketName(), + clientConfiguration.getMultipartUploadPartSize(), + clientConfiguration.getMultipartUploadThreshold()); + } else { + this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + } } /** @@ -524,10 +533,37 @@ private CompletableFuture storeMessageInS3(SendMessageReques private CompletableFuture storeOriginalPayload(String messageContentStr) { String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); - if (StringUtils.isBlank(s3KeyPrefix)) { - return payloadStore.storeOriginalPayload(messageContentStr); + String key = StringUtils.isBlank(s3KeyPrefix) ? UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); + + if (clientConfiguration.isMultipartUploadEnabled()) { + CompletableFuture multipartResult = tryMultipartUploadAsync(messageContentStr, key); + return multipartResult.thenCompose(result -> result != null ? 
+ CompletableFuture.completedFuture(result) : + payloadStore.storeOriginalPayload(messageContentStr, key)); + } + + return payloadStore.storeOriginalPayload(messageContentStr, key); + } + + private CompletableFuture tryMultipartUploadAsync(String payload, String candidateKey) { + if (!(payloadStore instanceof software.amazon.payloadoffloading.MultipartPayloadStoreAsync)) { + return CompletableFuture.completedFuture(null); + } + long sizeBytes = Util.getStringSizeInBytes(payload); + if (sizeBytes < clientConfiguration.getMultipartUploadThreshold()) { + return CompletableFuture.completedFuture(null); + } + try { + return ((software.amazon.payloadoffloading.MultipartPayloadStoreAsync) payloadStore) + .storeOriginalPayloadMultipart(payload, candidateKey) + .exceptionally(ex -> { + LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); + return null; + }); + } catch (RuntimeException e) { + LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); + return CompletableFuture.completedFuture(null); } - return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); } private static T appendUserAgent(final T builder) { diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 5b372a9..42f3554 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -76,7 +76,9 @@ import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException; import software.amazon.awssdk.utils.StringUtils; import software.amazon.payloadoffloading.PayloadStore; +import software.amazon.payloadoffloading.MultipartPayloadStore; import software.amazon.payloadoffloading.S3BackedPayloadStore; +import software.amazon.payloadoffloading.S3BackedMultipartPayloadStore; import software.amazon.payloadoffloading.S3Dao; import software.amazon.payloadoffloading.Util; @@ -145,7 +147,15 @@ public AmazonSQSExtendedClient(SqsClient sqsClient, ExtendedClientConfiguration S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); - this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); + if (clientConfiguration.isMultipartUploadEnabled()) { + this.payloadStore = new S3BackedMultipartPayloadStore( + s3Dao, + clientConfiguration.getS3BucketName(), + clientConfiguration.getMultipartUploadPartSize(), + clientConfiguration.getMultipartUploadThreshold()); + } else { + this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); + } } /** @@ -895,13 +905,32 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques private String storeOriginalPayload(String messageContentStr) { String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); - if (StringUtils.isBlank(s3KeyPrefix)) { - return payloadStore.storeOriginalPayload(messageContentStr); + String key = StringUtils.isBlank(s3KeyPrefix) ? 
UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); + + if (clientConfiguration.isMultipartUploadEnabled()) { + String multipartResult = tryMultipartUpload(messageContentStr, key); + if (multipartResult != null) { + return multipartResult; + } + } + + return payloadStore.storeOriginalPayload(messageContentStr, key); + } + + private String tryMultipartUpload(String payload, String candidateKey) { + long sizeBytes = Util.getStringSizeInBytes(payload); + if (sizeBytes < clientConfiguration.getMultipartUploadThreshold()) { + return null; + } + + try { + return ((MultipartPayloadStore) payloadStore).storeOriginalPayloadMultipart(payload, candidateKey); + } catch (RuntimeException e) { + LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); + return null; } - return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); } - @SuppressWarnings("unchecked") private static T appendUserAgent(final T builder) { return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION); } diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index e96fff2..09d192f 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -212,4 +212,29 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry this.setServerSideEncryptionStrategy(serverSideEncryption); return this; } + + /** + * Enables or disables multipart upload support for large payload storage operations. + * @param enabled true to enable multipart uploads when threshold exceeded. + * @return updated configuration + */ + public ExtendedAsyncClientConfiguration withMultipartUploadEnabled(boolean enabled) { + setMultipartUploadEnabled(enabled); + return this; + } + + /** + * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * @param threshold threshold in bytes (>0). Values <=0 reset to default (5MB) + * @return updated configuration + */ + public ExtendedAsyncClientConfiguration withMultipartUploadThreshold(int threshold) { + setMultipartUploadThreshold(threshold); + return this; + } + + public ExtendedAsyncClientConfiguration withMultipartUploadPartSize(int partSize) { + setMultipartUploadPartSize(partSize); + return this; + } } \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 75a30f8..49ab26b 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -232,6 +232,36 @@ public ExtendedClientConfiguration withServerSideEncryption(ServerSideEncryption return this; } + /** + * Enables or disables multipart upload support for large payload storage operations. + * @param enabled true to enable multipart uploads when threshold exceeded. + * @return updated configuration + */ + public ExtendedClientConfiguration withMultipartUploadEnabled(boolean enabled) { + setMultipartUploadEnabled(enabled); + return this; + } + + /** + * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. + * @param threshold threshold in bytes (>0). 
Values <=0 reset to default (5MB) + * @return updated configuration + */ + public ExtendedClientConfiguration withMultipartUploadThreshold(int threshold) { + setMultipartUploadThreshold(threshold); + return this; + } + + + /** + * Sets the multipart upload part size (in bytes). Only used when multipart upload is enabled. + * @param partSize part size in bytes (>0). Values <=0 reset to default (5MB) + */ + public ExtendedClientConfiguration withMultipartUploadPartSize(int partSize) { + setMultipartUploadPartSize(partSize); + return this; + } + /** * Enables support for large-payload messages. * diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java index 219e57d..36bd9c9 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -88,6 +88,11 @@ public class AmazonSQSExtendedAsyncClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; + + // Multipart upload thresholds for testing + private static final int MULTIPART_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int LESS_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD - 1; + private static final int MORE_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD + 1; @BeforeEach public void setupClients() { @@ -729,4 +734,193 @@ private String generateStringWithLength(int messageLength) { Arrays.fill(charArray, 'x'); return new String(charArray); } + + @Test + public void testWhenMultipartUploadDisabledThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(false); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAndMessageBelowThresholdThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(LESS_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAndMessageAboveThresholdThenMultipartIsAttempted() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + 
.withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithCustomThresholdThenThresholdIsHonored() { + int customThreshold = 1024 * 1024; // 1MB + int messageLength = customThreshold + 1000; // Just above custom threshold + String messageBody = generateStringWithLength(messageLength); + + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(customThreshold); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithAlwaysThroughS3ThenSmallMessagesAlsoUseS3() { + String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) + .withAlwaysThroughS3(true); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledThenConfigurationIsSetCorrectly() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + + assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + assertEquals(MULTIPART_UPLOAD_THRESHOLD, extendedClientConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testWhenMultipartUploadDisabledByDefaultThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); + verify(mockS3, 
times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithMessageBatchThenLargeMessagesUseMultipart() { + int customThreshold = 500_000; // 500KB multipart threshold + + // 3 messages above SQS size limit (256KB) will be stored in S3: + // - 300K uses standard S3 upload + // - 600K and 700K use multipart upload (above 500K threshold) + int[] messageLengthForCounter = new int[] { + 100_000, + 200_000, + 300_000, + 600_000, + 700_000 + }; + + List batchEntries = new ArrayList<>(); + for (int i = 0; i < messageLengthForCounter.length; i++) { + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() + .id("entry_" + i) + .messageBody(messageBody) + .build(); + batchEntries.add(entry); + } + + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(customThreshold); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); + sqsExtended.sendMessageBatch(batchRequest).join(); + + verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + verify(mockSqsBackend, times(1)).sendMessageBatch(isA(SendMessageBatchRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithCustomKMSThenKMSIsAppliedToMultipart() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) + .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + assertEquals(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY, extendedClientConfiguration.getServerSideEncryptionStrategy()); + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAtExactThresholdThenMultipartIsNotUsed() { + String messageBody = generateStringWithLength(MULTIPART_UPLOAD_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMultipartUploadFailsThenFallsBackToStandardUpload() { + 
String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } } diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 96ac44f..9459fd0 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -71,6 +71,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.isA; import static org.mockito.Mockito.mock; @@ -113,6 +114,11 @@ public class AmazonSQSExtendedClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; + + // Multipart upload thresholds + private static final int MULTIPART_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int LESS_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD - 1; + private static final int MORE_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD + 1; @BeforeEach public void setupClients() { @@ -777,4 +783,195 @@ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle) private String getSampleLargeReceiptHandle(String originalReceiptHandle) { return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle); } + + @Test + public void testWhenMultipartUploadDisabledThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(false); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAndMessageBelowThresholdThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(LESS_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = 
SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAndMessageAboveThresholdThenMultipartIsAttempted() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithCustomThresholdThenThresholdIsHonored() { + int customThreshold = 1024 * 1024; + int messageLength = customThreshold + 1000; + String messageBody = generateStringWithLength(messageLength); + + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(customThreshold); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithAlwaysThroughS3ThenSmallMessagesAlsoUseS3() { + String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) + .withAlwaysThroughS3(true); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithS3KeyPrefixThenPrefixIsUsed() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) + .withS3KeyPrefix(S3_KEY_PREFIX); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void 
testWhenMultipartUploadEnabledThenConfigurationWithCorrectThreshold() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + + assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + assertEquals(MULTIPART_UPLOAD_THRESHOLD, extendedClientConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testWhenMultipartUploadDisabledByDefaultThenStandardUploadIsUsed() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithMessageBatchThenLargeMessagesUseMultipart() { + int customThreshold = 500_000; // 500KB multipart threshold + + // 3 messages above SQS size limit (256KB) will be stored in S3: + // - 300K uses standard S3 upload + // - 600K and 700K use multipart upload (above 500K threshold) + int[] messageLengthForCounter = new int[] { + 100_000, + 200_000, + 300_000, + 600_000, + 700_000 + }; + + List batchEntries = new ArrayList<>(); + for (int i = 0; i < messageLengthForCounter.length; i++) { + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() + .id("entry_" + i) + .messageBody(messageBody) + .build(); + batchEntries.add(entry); + } + + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(customThreshold); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); + sqsExtended.sendMessageBatch(batchRequest); + + verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + verify(mockSqsBackend, times(1)).sendMessageBatch(isA(SendMessageBatchRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledWithCustomKMSThenKMSIsAppliedToMultipart() { + String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) + .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + 
sqsExtended.sendMessage(messageRequest); + + assertEquals(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY, extendedClientConfiguration.getServerSideEncryptionStrategy()); + verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + } + + @Test + public void testWhenMultipartUploadEnabledAtExactThresholdThenMultipartIsNotUsed() { + String messageBody = generateStringWithLength(MULTIPART_UPLOAD_THRESHOLD); + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + } + } diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java index 879f098..5004b65 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java @@ -85,4 +85,47 @@ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() { assertNotNull(extendedClientConfiguration.getS3AsyncClient()); assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName()); } + + @Test + public void testMultipartUploadDisabledByDefault() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + + assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); + assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadThreshold()); // 5MB default + } + + @Test + public void testMultipartUploadEnabled() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.withMultipartUploadEnabled(true); + + assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + } + + @Test + public void testMultipartUploadThresholdCustomValue() { + int customThreshold = 10 * 1024 * 1024; // 10MB + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.withMultipartUploadThreshold(customThreshold); + + assertEquals(customThreshold, extendedClientConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testMultipartUploadPartSizeCustomValue() { + int customPartSize = 10 * 1024 * 1024; // 10MB + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.withMultipartUploadPartSize(customPartSize); + + assertEquals(customPartSize, extendedClientConfiguration.getMultipartUploadPartSize()); + } + + @Test + public void testMultipartUploadPartSizeBelowMinimumRoundedUpTo5MB() { + int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum) + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.withMultipartUploadPartSize(belowMinimum); + + assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadPartSize()); + } } diff --git 
a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java index 2dc5b6b..fafa791 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java @@ -224,4 +224,58 @@ public void testS3keyPrefixWithALargeString() { assertThrows(SdkClientException.class, () -> extendedClientConfiguration.withS3KeyPrefix(s3KeyPrefix)); } + + @Test + public void testMultipartUploadEnabled() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.withMultipartUploadEnabled(true); + + assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + } + + @Test + public void testMultipartUploadThresholdCustomValue() { + int customThreshold = 10 * 1024 * 1024; // 10MB + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.withMultipartUploadThreshold(customThreshold); + + assertEquals(customThreshold, extendedClientConfiguration.getMultipartUploadThreshold()); + } + + @Test + public void testMultipartUploadPartSizeCustomValue() { + int customPartSize = 10 * 1024 * 1024; // 10MB + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.withMultipartUploadPartSize(customPartSize); + + assertEquals(customPartSize, extendedClientConfiguration.getMultipartUploadPartSize()); + } + + @Test + public void testMultipartUploadPartSizeBelowMinimumRoundedUpTo5MB() { + int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum) + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.withMultipartUploadPartSize(belowMinimum); + + assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadPartSize()); + } + + @Test + public void testMultipartConfigurationInCopyConstructor() { + S3Client s3 = mock(S3Client.class); + int customThreshold = 10 * 1024 * 1024; + int customPartSize = 8 * 1024 * 1024; + + ExtendedClientConfiguration originalConfig = new ExtendedClientConfiguration(); + originalConfig.withPayloadSupportEnabled(s3, s3BucketName) + .withMultipartUploadEnabled(true) + .withMultipartUploadThreshold(customThreshold) + .withMultipartUploadPartSize(customPartSize); + + ExtendedClientConfiguration copiedConfig = new ExtendedClientConfiguration(originalConfig); + + assertTrue(copiedConfig.isMultipartUploadEnabled()); + assertEquals(customThreshold, copiedConfig.getMultipartUploadThreshold()); + assertEquals(customPartSize, copiedConfig.getMultipartUploadPartSize()); + } } From 588473eeea4c32052a3f963c0edbf7aa71b5fb70 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 09:17:10 +0700 Subject: [PATCH 2/8] adding stream message support --- .../AmazonSQSExtendedAsyncClient.java | 307 ++++++++++-- .../AmazonSQSExtendedClient.java | 348 ++++++++++++-- .../ExtendedAsyncClientConfiguration.java | 25 +- .../ExtendedClientConfiguration.java | 29 +- .../ReceiveStreamMessageResponse.java | 21 + .../sqs/javamessaging/StreamMessage.java | 40 ++ .../AmazonSQSExtendedAsyncClientTest.java | 448 ++++++++++++------ .../AmazonSQSExtendedClientTest.java | 419 ++++++++-------- .../ExtendedAsyncClientConfigurationTest.java | 30 +- .../ExtendedClientConfigurationTest.java | 38 +- 10 files changed, 1242 insertions(+), 463 
deletions(-) create mode 100644 src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java create mode 100644 src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 7907cc3..0bb5ceb 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -9,6 +9,10 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import java.io.IOException; +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -47,10 +51,12 @@ import software.amazon.awssdk.services.sqs.model.SendMessageRequest; import software.amazon.awssdk.services.sqs.model.SendMessageResponse; import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.utils.IoUtils; import software.amazon.payloadoffloading.PayloadStoreAsync; import software.amazon.payloadoffloading.S3AsyncDao; import software.amazon.payloadoffloading.S3BackedPayloadStoreAsync; -import software.amazon.payloadoffloading.S3BackedMultipartPayloadStoreAsync; +import software.amazon.payloadoffloading.S3BackedStreamPayloadStoreAsync; +import software.amazon.payloadoffloading.StreamPayloadStoreAsync; import software.amazon.payloadoffloading.Util; /** @@ -128,14 +134,11 @@ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient, super(sqsClient); this.clientConfiguration = new ExtendedAsyncClientConfiguration(extendedClientConfig); S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(), - clientConfiguration.getServerSideEncryptionStrategy(), - clientConfiguration.getObjectCannedACL()); - if (clientConfiguration.isMultipartUploadEnabled()) { - this.payloadStore = new S3BackedMultipartPayloadStoreAsync( - s3Dao, - clientConfiguration.getS3BucketName(), - clientConfiguration.getMultipartUploadPartSize(), - clientConfiguration.getMultipartUploadThreshold()); + clientConfiguration.getServerSideEncryptionStrategy(), + clientConfiguration.getObjectCannedACL()); + + if (clientConfiguration.isStreamUploadEnabled()) { + this.payloadStore = new S3BackedStreamPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); } else { this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); } @@ -293,6 +296,154 @@ public CompletableFuture receiveMessage(ReceiveMessageRe }); } + /** + * Receives messages from the specified queue with streaming access to large payloads. 
+ * + * @param receiveMessageRequest The receive message request + * @return A CompletableFuture containing the response with streaming message access + */ + public CompletableFuture receiveMessageAsStream(ReceiveMessageRequest receiveMessageRequest) { + if (receiveMessageRequest == null) { + String errorMessage = "receiveMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); + appendUserAgent(receiveMessageRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // If payload support is disabled, fall back to regular receive and wrap in stream response + return super.receiveMessage(receiveMessageRequestBuilder.build()) + .thenApply(receiveMessageResponse -> { + List streamMessages = receiveMessageResponse.messages().stream() + .map(message -> new StreamMessage(message, null)) + .collect(Collectors.toList()); + return new ReceiveStreamMessageResponse(streamMessages); + }); + } + + // Remove before adding to avoid any duplicates + List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); + String queueUrl = receiveMessageRequest.queueUrl(); + receiveMessageRequest = receiveMessageRequestBuilder.build(); + + return super.receiveMessage(receiveMessageRequest) + .thenCompose(receiveMessageResponse -> { + List messages = receiveMessageResponse.messages(); + + // Check for no messages. If so, no need to process further. + if (messages.isEmpty()) { + return CompletableFuture.completedFuture(new ArrayList()); + } + + List> streamMessageFutures = new ArrayList<>(messages.size()); + for (Message message : messages) { + // For each received message check if they are stored in S3. 
+ Optional largePayloadAttributeName = getReservedAttributeNameIfPresent( + message.messageAttributes()); + if (!largePayloadAttributeName.isPresent()) { + // Not S3 - create StreamMessage without stream + streamMessageFutures.add(CompletableFuture.completedFuture(new StreamMessage(message, null))); + } else { + // In S3 - get streaming access to payload + final String largeMessagePointer = message.body() + .replace("com.amazon.sqs.javamessaging.MessageS3Pointer", + "software.amazon.payloadoffloading.PayloadS3Pointer"); + + // Check if we have stream payload store for streaming + if (payloadStore instanceof StreamPayloadStoreAsync) { + StreamPayloadStoreAsync streamStore = (StreamPayloadStoreAsync) payloadStore; + + streamMessageFutures.add(streamStore.getOriginalPayloadStreamStream(largeMessagePointer) + .handle((stream, throwable) -> { + if (throwable != null) { + if (clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(queueUrl) + .receiptHandle(message.receiptHandle()) + .build(); + + deleteMessage(deleteMessageRequest).join(); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + return null; + } else { + throw new CompletionException(throwable); + } + } + + Message.Builder messageBuilder = message.toBuilder(); + // Remove the additional attribute before returning the message + Map messageAttributes = new HashMap<>( + message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. + String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + messageBuilder.receiptHandle(modifiedReceiptHandle); + + return new StreamMessage(messageBuilder.build(), stream); + })); + } else { + // Fall back to regular payload retrieval if no streaming support + streamMessageFutures.add(payloadStore.getOriginalPayload(largeMessagePointer) + .handle((originalPayload, throwable) -> { + if (throwable != null) { + if (clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(queueUrl) + .receiptHandle(message.receiptHandle()) + .build(); + + deleteMessage(deleteMessageRequest).join(); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + return null; + } else { + throw new CompletionException(throwable); + } + } + + // Set original payload and create StreamMessage without stream + Message.Builder messageBuilder = message.toBuilder(); + messageBuilder.body(originalPayload); + + // Remove the additional attribute before returning the message + Map messageAttributes = new HashMap<>( + message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. + String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + messageBuilder.receiptHandle(modifiedReceiptHandle); + + return new StreamMessage(messageBuilder.build(), null); + })); + } + } + } + + // Convert list of stream message futures to a future list of stream messages. 
+ return CompletableFuture.allOf( + streamMessageFutures.toArray(new CompletableFuture[streamMessageFutures.size()])) + .thenApply(v -> streamMessageFutures.stream() + .map(CompletableFuture::join) + .filter(Objects::nonNull) + .collect(Collectors.toList())); + }) + .thenApply(streamMessages -> new ReceiveStreamMessageResponse(streamMessages)); + } + /** * {@inheritDoc} */ @@ -452,6 +603,69 @@ public CompletableFuture deleteMessageBatch( return super.deleteMessageBatch(deleteMessageBatchRequestBuilder.build()); } + /** + *
+ * Delivers a message to the specified queue using an InputStream for the message body. + * This method allows sending large messages without loading them entirely into memory. + *
+ * + * @param sendMessageRequest The send message request with message body as InputStream + * @param messageBodyStream InputStream containing the message body content + * @param contentLength The total length of the content in the stream + * @return CompletableFuture of the SendMessage operation returned by the service. + */ + public CompletableFuture sendStreamMessage(SendMessageRequest sendMessageRequest, + InputStream messageBodyStream, + long contentLength) { + if (sendMessageRequest == null) { + String errorMessage = "sendMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert stream to string for non-extended client + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + SendMessageRequest finalRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(finalRequest); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; + } + } + + if (messageBodyStream == null) { + String errorMessage = "messageBodyStream cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + return storeStreamMessageInS3(sendMessageRequest, messageBodyStream, contentLength) + .thenCompose(modifiedRequest -> super.sendMessage(modifiedRequest)); + } else { + // Convert stream to string for small messages + try { + String messageBody = IoUtils.toUtf8String(messageBodyStream); + SendMessageRequest finalRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(finalRequest); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; + } + } + } + /** * {@inheritDoc} */ @@ -533,36 +747,61 @@ private CompletableFuture storeMessageInS3(SendMessageReques private CompletableFuture storeOriginalPayload(String messageContentStr) { String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); - String key = StringUtils.isBlank(s3KeyPrefix) ? UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); - - if (clientConfiguration.isMultipartUploadEnabled()) { - CompletableFuture multipartResult = tryMultipartUploadAsync(messageContentStr, key); - return multipartResult.thenCompose(result -> result != null ? 
- CompletableFuture.completedFuture(result) : - payloadStore.storeOriginalPayload(messageContentStr, key)); + if (StringUtils.isBlank(s3KeyPrefix)) { + return payloadStore.storeOriginalPayload(messageContentStr); } - - return payloadStore.storeOriginalPayload(messageContentStr, key); + return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); } - private CompletableFuture tryMultipartUploadAsync(String payload, String candidateKey) { - if (!(payloadStore instanceof software.amazon.payloadoffloading.MultipartPayloadStoreAsync)) { - return CompletableFuture.completedFuture(null); - } - long sizeBytes = Util.getStringSizeInBytes(payload); - if (sizeBytes < clientConfiguration.getMultipartUploadThreshold()) { - return CompletableFuture.completedFuture(null); - } - try { - return ((software.amazon.payloadoffloading.MultipartPayloadStoreAsync) payloadStore) - .storeOriginalPayloadMultipart(payload, candidateKey) + + private CompletableFuture storeStreamMessageInS3(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + + sendMessageRequestBuilder.messageAttributes( + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + return storeOriginalPayload(messageBodyStream) + .thenApply(largeMessagePointer -> { + sendMessageRequestBuilder.messageBody(largeMessagePointer); + return sendMessageRequestBuilder.build(); + }); + } + + private CompletableFuture storeOriginalPayload(InputStream messageContentStream) { + String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); + String key = StringUtils.isBlank(s3KeyPrefix) ? 
UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); + + if (payloadStore instanceof StreamPayloadStoreAsync) { + return ((StreamPayloadStoreAsync) payloadStore) + .storeOriginalPayloadStream(messageContentStream, key) .exceptionally(ex -> { - LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); - return null; + LOG.warn("Stream upload attempt failed; falling back to standard single-part upload."); + try { + messageContentStream.reset(); + } catch (Exception resetEx) { + LOG.warn("Failed to reset stream after multipart upload failure, stream may be exhausted", resetEx); + throw new RuntimeException("Stream upload failed and stream cannot be reset", ex); + } + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key).join(); + } catch (IOException ioEx) { + throw new RuntimeException("Failed to read from InputStream", ioEx); + } }); - } catch (RuntimeException e) { - LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); - return CompletableFuture.completedFuture(null); + } + + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException e) { + CompletableFuture failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(new RuntimeException("Failed to read from InputStream", e)); + return failedFuture; } } diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 42f3554..7fecea3 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -24,6 +24,7 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import java.net.URI; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,11 +34,19 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.util.VersionInfo; +import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.model.NoSuchKeyException; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.BatchEntryIdsNotDistinctException; @@ -75,13 +84,17 @@ import software.amazon.awssdk.services.sqs.model.SqsException; import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException; import software.amazon.awssdk.utils.StringUtils; +import 
software.amazon.awssdk.utils.IoUtils; import software.amazon.payloadoffloading.PayloadStore; -import software.amazon.payloadoffloading.MultipartPayloadStore; +import software.amazon.payloadoffloading.StreamPayloadStore; import software.amazon.payloadoffloading.S3BackedPayloadStore; -import software.amazon.payloadoffloading.S3BackedMultipartPayloadStore; +import software.amazon.payloadoffloading.S3BackedStreamPayloadStore; import software.amazon.payloadoffloading.S3Dao; import software.amazon.payloadoffloading.Util; +import java.io.InputStream; +import java.io.ByteArrayInputStream; +import java.io.IOException; /** * Amazon SQS Extended Client extends the functionality of Amazon SQS client. @@ -144,16 +157,22 @@ public AmazonSQSExtendedClient(SqsClient sqsClient) { public AmazonSQSExtendedClient(SqsClient sqsClient, ExtendedClientConfiguration extendedClientConfig) { super(sqsClient); this.clientConfiguration = new ExtendedClientConfiguration(extendedClientConfig); - S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), - clientConfiguration.getServerSideEncryptionStrategy(), - clientConfiguration.getObjectCannedACL()); - if (clientConfiguration.isMultipartUploadEnabled()) { - this.payloadStore = new S3BackedMultipartPayloadStore( - s3Dao, - clientConfiguration.getS3BucketName(), - clientConfiguration.getMultipartUploadPartSize(), - clientConfiguration.getMultipartUploadThreshold()); + if (clientConfiguration.isStreamUploadEnabled()) { + S3AsyncClient s3AsyncClient = S3AsyncClient.builder() + .multipartEnabled(true) + .multipartConfiguration( + multipartConfig -> multipartConfig + .minimumPartSizeInBytes(clientConfiguration.getStreamUploadPartSize()) + .thresholdInBytes(clientConfiguration.getStreamUploadThreshold()) + ) + .credentialsProvider(DefaultCredentialsProvider.create()) + .region(Region.of(clientConfiguration.getS3Region())) + .build(); + + S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), s3AsyncClient, clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); + this.payloadStore = new S3BackedStreamPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); } else { + S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), clientConfiguration.getS3AsyncClient(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); } } @@ -386,6 +409,88 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag return receiveMessageResponseBuilder.build(); } + public ReceiveStreamMessageResponse receiveMessageAsStream(ReceiveMessageRequest receiveMessageRequest) { + //TODO: Clone request since it's modified in this method and will cause issues if the client reuses request object.
+ if (receiveMessageRequest == null) { + String errorMessage = "receiveMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); + appendUserAgent(receiveMessageRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // If payload support is disabled, fall back to regular receive and wrap in StreamMessage with null streams + ReceiveMessageResponse response = super.receiveMessage(receiveMessageRequestBuilder.build()); + List streamMessages = response.messages().stream() + .map(message -> new StreamMessage(message, null)) + .collect(java.util.stream.Collectors.toList()); + return new ReceiveStreamMessageResponse(streamMessages); + } + + //Remove before adding to avoid any duplicates + List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); + receiveMessageRequest = receiveMessageRequestBuilder.build(); + + ReceiveMessageResponse receiveMessageResponse = super.receiveMessage(receiveMessageRequest); + List streamMessages = new ArrayList<>(receiveMessageResponse.messages().size()); + + for (Message message : receiveMessageResponse.messages()) { + Message.Builder messageBuilder = message.toBuilder(); + ResponseInputStream payloadStream = null; + + // for each received message check if they are stored in S3. + Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(message.messageAttributes()); + if (largePayloadAttributeName.isPresent()) { + String largeMessagePointer = message.body(); + largeMessagePointer = largeMessagePointer.replace("com.amazon.sqs.javamessaging.MessageS3Pointer", "software.amazon.payloadoffloading.PayloadS3Pointer"); + + try { + if (payloadStore instanceof StreamPayloadStore) { + payloadStream = ((StreamPayloadStore) payloadStore).getOriginalPayloadStreamStream(largeMessagePointer); + } else { + // Fallback: load into memory and create a stream from it + System.out.println("Warning: payload store is not a StreamPayloadStore, loading entire payload into memory"); + String payload = payloadStore.getOriginalPayload(largeMessagePointer); + ByteArrayInputStream byteStream = new ByteArrayInputStream(payload.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + payloadStream = new ResponseInputStream<>(GetObjectResponse.builder().build(), byteStream); + } + } catch (SdkException e) { + if (e.getCause() instanceof NoSuchKeyException && clientConfiguration.ignoresPayloadNotFound()) { + DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest + .builder() + .queueUrl(receiveMessageRequest.queueUrl()) + .receiptHandle(message.receiptHandle()) + .build(); + deleteMessage(deleteMessageRequest); + LOG.warn("Message deleted from SQS since payload with pointer could not be found in S3."); + continue; + } else throw e; + } + + // remove the additional attribute before returning the message to user + Map messageAttributes = new HashMap<>(message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. 
+ String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + + messageBuilder.receiptHandle(modifiedReceiptHandle); + } + + streamMessages.add(new StreamMessage(messageBuilder.build(), payloadStream)); + } + + return new ReceiveStreamMessageResponse(streamMessages); + } + /** *
<p>
* Deletes the specified message from the specified queue. To select the message to delete, use the @@ -662,6 +767,150 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes return super.sendMessageBatch(sendMessageBatchRequest); } + /** + *
<p>
+ * Delivers a message to the specified queue using an InputStream for the message body. + * This method allows sending large messages without loading them entirely into memory. + *
</p>
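+ * A minimal usage sketch (the queue URL, file path and {@code extendedClient} instance are illustrative placeholders, not part of this change):
+ * <pre>{@code
+ * Path file = Paths.get("large-payload.json");
+ * try (InputStream body = Files.newInputStream(file)) {
+ *     extendedClient.sendStreamMessage(
+ *         SendMessageRequest.builder().queueUrl(queueUrl).build(),
+ *         body,
+ *         Files.size(file));
+ * }
+ * }</pre>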
+ * + * @param sendMessageRequest The send message request with message body as InputStream + * @param messageBodyStream InputStream containing the message body content + * @param contentLength The total length of the content in the stream + * @return Result of the SendMessage operation returned by the service. + */ + public SendMessageResponse sendStreamMessage(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + if (sendMessageRequest == null) { + String errorMessage = "sendMessageRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert stream to string for non-extended client + try { + String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStream); + sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + return super.sendMessage(sendMessageRequest); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + + if (messageBodyStream == null) { + String errorMessage = "messageBodyStream cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + sendMessageRequest = storeStreamMessageInS3(sendMessageRequest, messageBodyStream, contentLength); + } else { + // Convert stream to string for small messages + try { + String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStream); + sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + return super.sendMessage(sendMessageRequest); + } + + /** + *
<p>
+ * Delivers up to ten messages to the specified queue using InputStreams for message bodies. + * This method allows sending large messages without loading them entirely into memory. + *
</p>
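+ * A batch sketch (entry ids, streams and sizes are illustrative; the i-th stream and length must correspond to the i-th batch entry):
+ * <pre>{@code
+ * SendMessageBatchRequest batch = SendMessageBatchRequest.builder()
+ *     .queueUrl(queueUrl)
+ *     .entries(SendMessageBatchRequestEntry.builder().id("entry-1").build(),
+ *              SendMessageBatchRequestEntry.builder().id("entry-2").build())
+ *     .build();
+ * extendedClient.sendStreamMessageBatch(batch,
+ *     Arrays.asList(streamOne, streamTwo),
+ *     Arrays.asList(sizeOne, sizeTwo));
+ * }</pre>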
+ * + * @param sendMessageBatchRequest The send message batch request + * @param messageBodyStreams List of InputStreams containing message body content for each entry + * @param contentLengths List of content lengths corresponding to each stream + * @return Result of the SendMessageBatch operation returned by the service. + */ + public SendMessageBatchResponse sendStreamMessageBatch(SendMessageBatchRequest sendMessageBatchRequest, + java.util.List messageBodyStreams, + java.util.List contentLengths) { + + if (sendMessageBatchRequest == null) { + String errorMessage = "sendMessageBatchRequest cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (messageBodyStreams == null || contentLengths == null) { + String errorMessage = "messageBodyStreams and contentLengths cannot be null."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + if (messageBodyStreams.size() != sendMessageBatchRequest.entries().size() || + contentLengths.size() != sendMessageBatchRequest.entries().size()) { + String errorMessage = "messageBodyStreams and contentLengths must have the same size as batch entries."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + SendMessageBatchRequest.Builder sendMessageBatchRequestBuilder = sendMessageBatchRequest.toBuilder(); + appendUserAgent(sendMessageBatchRequestBuilder); + sendMessageBatchRequest = sendMessageBatchRequestBuilder.build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + // Convert streams to strings for non-extended client + java.util.List entries = new java.util.ArrayList<>(); + for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) { + SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i); + try { + String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStreams.get(i)); + entries.add(entry.toBuilder().messageBody(messageBody).build()); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(entries).build(); + return super.sendMessageBatch(sendMessageBatchRequest); + } + + List batchEntries = new ArrayList<>(sendMessageBatchRequest.entries().size()); + + boolean hasS3Entries = false; + for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) { + SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i); + InputStream stream = messageBodyStreams.get(i); + long contentLength = contentLengths.get(i); + + //Check message attributes for ExtendedClient related constraints + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); + + if (clientConfiguration.isAlwaysThroughS3() + || contentLength >= clientConfiguration.getPayloadSizeThreshold()) { + entry = storeStreamMessageInS3(entry, stream, contentLength); + hasS3Entries = true; + } else { + // Convert stream to string for small messages + try { + String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(stream); + entry = entry.toBuilder().messageBody(messageBody).build(); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); + } + } + batchEntries.add(entry); + } + + if (hasS3Entries) { + sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(batchEntries).build(); + } + + return super.sendMessageBatch(sendMessageBatchRequest); + } + /** *
<p>
* Deletes up to ten messages from the specified queue. This is a batch version of @@ -904,33 +1153,76 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques } private String storeOriginalPayload(String messageContentStr) { + String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); + if (StringUtils.isBlank(s3KeyPrefix)) { + return payloadStore.storeOriginalPayload(messageContentStr); + } + return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID()); + } + + private SendMessageRequest storeStreamMessageInS3(SendMessageRequest sendMessageRequest, InputStream messageBodyStream, long contentLength) { + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + + sendMessageRequestBuilder.messageAttributes( + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + String largeMessagePointer = storeOriginalPayload(messageBodyStream); + sendMessageRequestBuilder.messageBody(largeMessagePointer); + + return sendMessageRequestBuilder.build(); + } + + private SendMessageBatchRequestEntry storeStreamMessageInS3(SendMessageBatchRequestEntry batchEntry, InputStream messageBodyStream, long contentLength) { + SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder(); + + batchEntryBuilder.messageAttributes( + updateMessageAttributePayloadSize(batchEntry.messageAttributes(), contentLength, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + String largeMessagePointer = storeOriginalPayload(messageBodyStream); + batchEntryBuilder.messageBody(largeMessagePointer); + + return batchEntryBuilder.build(); + } + + private String storeOriginalPayload(InputStream messageContentStream) { String s3KeyPrefix = clientConfiguration.getS3KeyPrefix(); String key = StringUtils.isBlank(s3KeyPrefix) ? 
UUID.randomUUID().toString() : s3KeyPrefix + UUID.randomUUID(); - if (clientConfiguration.isMultipartUploadEnabled()) { - String multipartResult = tryMultipartUpload(messageContentStr, key); - if (multipartResult != null) { - return multipartResult; + if (payloadStore instanceof StreamPayloadStore) { + try { + return ((StreamPayloadStore) payloadStore).storeOriginalPayloadStream(messageContentStream, key); + } catch (RuntimeException e) { + LOG.warn("Stream upload attempt failed; falling back to standard single-part upload."); + try { + messageContentStream.reset(); + } catch (Exception resetEx) { + LOG.warn("Failed to reset stream after multipart upload failure, stream may be exhausted", resetEx); + throw e; + } + // Fall back to reading the stream and using string method + try { + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException ioEx) { + throw new RuntimeException("Failed to read from InputStream", ioEx); + } } } - return payloadStore.storeOriginalPayload(messageContentStr, key); - } - - private String tryMultipartUpload(String payload, String candidateKey) { - long sizeBytes = Util.getStringSizeInBytes(payload); - if (sizeBytes < clientConfiguration.getMultipartUploadThreshold()) { - return null; - } - + // Fall back to reading the stream and using string method try { - return ((MultipartPayloadStore) payloadStore).storeOriginalPayloadMultipart(payload, candidateKey); - } catch (RuntimeException e) { - LOG.warn("Multipart upload attempt failed; falling back to standard single-part upload."); - return null; + String content = IoUtils.toUtf8String(messageContentStream); + return payloadStore.storeOriginalPayload(content, key); + } catch (IOException e) { + throw new RuntimeException("Failed to read from InputStream", e); } } + @SuppressWarnings("unchecked") private static T appendUserAgent(final T builder) { return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION); } diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index 09d192f..baf7d6e 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -214,27 +214,32 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry } /** - * Enables or disables multipart upload support for large payload storage operations. - * @param enabled true to enable multipart uploads when threshold exceeded. + * Enables or disables stream upload support for large payload storage operations. + * @param enabled true to enable stream uploads when threshold exceeded. * @return updated configuration */ - public ExtendedAsyncClientConfiguration withMultipartUploadEnabled(boolean enabled) { - setMultipartUploadEnabled(enabled); + public ExtendedAsyncClientConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); return this; } /** - * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. - * @param threshold threshold in bytes (>0). Values <=0 reset to default (5MB) + * Sets the threshold for stream upload in bytes. 
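+ * Only used when stream upload is enabled. A configuration sketch (bucket name and part/threshold sizes are illustrative):
+ * <pre>{@code
+ * ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration()
+ *     .withPayloadSupportEnabled(s3AsyncClient, "my-bucket")
+ *     .withStreamUploadEnabled(true)
+ *     .withStreamUploadThreshold(5 * 1024 * 1024)
+ *     .withStreamUploadPartSize(5 * 1024 * 1024);
+ * }</pre>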
+ * @param threshold the threshold in bytes * @return updated configuration */ - public ExtendedAsyncClientConfiguration withMultipartUploadThreshold(int threshold) { - setMultipartUploadThreshold(threshold); + public ExtendedAsyncClientConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); return this; } - public ExtendedAsyncClientConfiguration withMultipartUploadPartSize(int partSize) { - setMultipartUploadPartSize(partSize); + public ExtendedAsyncClientConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } + + public ExtendedAsyncClientConfiguration withS3Region(String s3Region) { + setS3Region(s3Region); return this; } } \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 49ab26b..a8ddcfc 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -233,32 +233,27 @@ public ExtendedClientConfiguration withServerSideEncryption(ServerSideEncryption } /** - * Enables or disables multipart upload support for large payload storage operations. - * @param enabled true to enable multipart uploads when threshold exceeded. + * Enables or disables stream upload support for large payload storage operations. + * @param enabled true to enable stream uploads when threshold exceeded. * @return updated configuration */ - public ExtendedClientConfiguration withMultipartUploadEnabled(boolean enabled) { - setMultipartUploadEnabled(enabled); + public ExtendedClientConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); return this; } - /** - * Sets the multipart upload threshold (in bytes). Only used when multipart upload is enabled. - * @param threshold threshold in bytes (>0). Values <=0 reset to default (5MB) - * @return updated configuration - */ - public ExtendedClientConfiguration withMultipartUploadThreshold(int threshold) { - setMultipartUploadThreshold(threshold); + public ExtendedClientConfiguration withStreamUploadThreshold(int threshold) { + setStreamUploadThreshold(threshold); return this; } + public ExtendedClientConfiguration withStreamUploadPartSize(int partSize) { + setStreamUploadPartSize(partSize); + return this; + } - /** - * Sets the multipart upload part size (in bytes). Only used when multipart upload is enabled. - * @param partSize part size in bytes (>0). Values <=0 reset to default (5MB) - */ - public ExtendedClientConfiguration withMultipartUploadPartSize(int partSize) { - setMultipartUploadPartSize(partSize); + public ExtendedClientConfiguration withS3Region(String s3Region) { + setS3Region(s3Region); return this; } diff --git a/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java new file mode 100644 index 0000000..c93d458 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/ReceiveStreamMessageResponse.java @@ -0,0 +1,21 @@ +package com.amazon.sqs.javamessaging; + +import java.util.List; + +/** + * Response containing messages with streaming access to large payloads. 
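+ * A consumption sketch ({@code extendedClient} and {@code request} are assumed to exist; the async client returns this response wrapped in a CompletableFuture):
+ * <pre>{@code
+ * ReceiveStreamMessageResponse response = extendedClient.receiveMessageAsStream(request);
+ * for (StreamMessage streamMessage : response.streamMessages()) {
+ *     if (streamMessage.hasStreamPayload()) {
+ *         try (InputStream payload = streamMessage.getPayloadStream()) {
+ *             // consume the S3-backed payload incrementally
+ *         }
+ *     } else {
+ *         String body = streamMessage.getMessage().body();
+ *     }
+ * }
+ * }</pre>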
+ */ +public class ReceiveStreamMessageResponse { + private final List streamMessages; + + public ReceiveStreamMessageResponse(List streamMessages) { + this.streamMessages = streamMessages; + } + + /** + * @return list of messages with streaming payload access + */ + public List streamMessages() { + return streamMessages; + } +} \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java new file mode 100644 index 0000000..9f08986 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/StreamMessage.java @@ -0,0 +1,40 @@ +package com.amazon.sqs.javamessaging; + +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.sqs.model.Message; + +/** + * Represents a message received from SQS that can provide streaming access to large payloads stored in S3. + * Instead of loading the entire payload into memory, this provides access to the payload as a stream. + */ +public class StreamMessage { + private final Message originalMessage; + private final ResponseInputStream payloadStream; + + public StreamMessage(Message originalMessage, ResponseInputStream payloadStream) { + this.originalMessage = originalMessage; + this.payloadStream = payloadStream; + } + + /** + * @return the original SQS message metadata (messageId, receiptHandle, attributes, etc.) + */ + public Message getMessage() { + return originalMessage; + } + + /** + * @return stream for accessing the message payload, or null if payload is not stored in S3 + */ + public ResponseInputStream getPayloadStream() { + return payloadStream; + } + + /** + * @return true if this message has a streaming payload from S3 + */ + public boolean hasStreamPayload() { + return payloadStream != null; + } +} \ No newline at end of file diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java index 36bd9c9..ff941c8 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -6,9 +6,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doThrow; @@ -20,6 +23,8 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -38,6 +43,7 @@ import software.amazon.awssdk.core.ResponseBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import 
software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; @@ -63,10 +69,14 @@ import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; import software.amazon.awssdk.services.sqs.model.SendMessageRequest; import software.amazon.awssdk.services.sqs.model.SendMessageResponse; -import software.amazon.awssdk.utils.ImmutableMap; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.payloadoffloading.StreamPayloadStoreAsync; +import software.amazon.payloadoffloading.PayloadStoreAsync; import software.amazon.payloadoffloading.PayloadS3Pointer; import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; +import software.amazon.awssdk.utils.ImmutableMap; public class AmazonSQSExtendedAsyncClientTest { @@ -89,10 +99,10 @@ public class AmazonSQSExtendedAsyncClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; - // Multipart upload thresholds for testing - private static final int MULTIPART_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default - private static final int LESS_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD - 1; - private static final int MORE_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD + 1; + // Stream upload thresholds + private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int LESS_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD - 1; + private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1; @BeforeEach public void setupClients() { @@ -736,191 +746,341 @@ private String generateStringWithLength(int messageLength) { } @Test - public void testWhenMultipartUploadDisabledThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(false); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + public void testReceiveMessageAsStream_PayloadSupportDisabled_ReturnsMessagesWithoutStreams() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportDisabled(); + AmazonSQSExtendedAsyncClient clientWithDisabledPayload = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); - } + Message message = Message.builder() + .messageId("msg1") + .body("small message") + .receiptHandle("receipt1") + .build(); - @Test - public void testWhenMultipartUploadEnabledAndMessageBelowThresholdThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(LESS_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, 
S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + CompletableFuture future = clientWithDisabledPayload.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertEquals("small message", streamMessage.getMessage().body()); + assertFalse(streamMessage.hasStreamPayload()); } @Test - public void testWhenMultipartUploadEnabledAndMessageAboveThresholdThenMultipartIsAttempted() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSizeThreshold(262144) // 256KB + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(1024 * 1024); // 1MB - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); - } + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); - @Test - public void testWhenMultipartUploadEnabledWithCustomThresholdThenThresholdIsHonored() { - int customThreshold = 1024 * 1024; // 1MB - int messageLength = customThreshold + 1000; // Just above custom threshold - String messageBody = generateStringWithLength(messageLength); + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(customThreshold); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, 
extendedClientConfiguration)); + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + @SuppressWarnings("unchecked") + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); // End of stream - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); - } + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); - @Test - public void testWhenMultipartUploadEnabledWithAlwaysThroughS3ThenSmallMessagesAlsoUseS3() { - String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) - .withAlwaysThroughS3(true); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + @SuppressWarnings("unchecked") + CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(futureStream); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + CompletableFuture future = clientWithStream.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertTrue(streamMessage.hasStreamPayload()); + assertSame(mockStream, streamMessage.getPayloadStream()); + + // Verify receipt handle was modified + assertTrue(streamMessage.getMessage().receiptHandle().contains("test-key")); } @Test - public void testWhenMultipartUploadEnabledThenConfigurationIsSetCorrectly() { - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); + public void testReceiveMessageAsStream_LargeMessage_WithoutStreamStore_FallsBackToRegularRetrieval() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withStreamUploadEnabled(false); - assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); - assertEquals(MULTIPART_UPLOAD_THRESHOLD, extendedClientConfiguration.getMultipartUploadThreshold()); - } + AmazonSQSExtendedAsyncClient clientWithRegularStore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); - @Test - public void testWhenMultipartUploadDisabledByDefaultThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - 
.withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); - assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + ResponseBytes s3Object = ResponseBytes.fromByteArray( + GetObjectResponse.builder().build(), + "large payload content".getBytes(StandardCharsets.UTF_8)); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(s3Object)); + + CompletableFuture future = clientWithRegularStore.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(1, response.streamMessages().size()); + StreamMessage streamMessage = response.streamMessages().get(0); + assertEquals("msg1", streamMessage.getMessage().messageId()); + assertEquals("large payload content", streamMessage.getMessage().body()); + assertFalse(streamMessage.hasStreamPayload()); } @Test - public void testWhenMultipartUploadEnabledWithMessageBatchThenLargeMessagesUseMultipart() { - int customThreshold = 500_000; // 500KB multipart threshold + public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundEnabled_DeletesMessage() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withIgnorePayloadNotFound(true) + .withStreamUploadEnabled(true); - // 3 messages above SQS size limit (256KB) will be stored in S3: - // - 300K uses standard S3 upload - // - 600K and 700K use multipart upload (above 500K threshold) - int[] messageLengthForCounter = new int[] { - 100_000, - 200_000, - 300_000, - 600_000, - 700_000 - }; + AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); - List batchEntries = new ArrayList<>(); - for (int i = 0; i < messageLengthForCounter.length; i++) { - int messageLength = messageLengthForCounter[i]; - String messageBody = generateStringWithLength(messageLength); - SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() - .id("entry_" + i) - .messageBody(messageBody) - .build(); - batchEntries.add(entry); - } + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(customThreshold); - 
SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); - SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); - sqsExtended.sendMessageBatch(batchRequest).join(); + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); - verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); - verify(mockSqsBackend, times(1)).sendMessageBatch(isA(SendMessageBatchRequest.class)); + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture> failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(NoSuchKeyException.builder().build()); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(failedFuture); + + when(mockSqsBackend.deleteMessage(any(DeleteMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(DeleteMessageResponse.builder().build())); + + CompletableFuture future = clientWithIgnore.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertTrue(response.streamMessages().isEmpty()); + + ArgumentCaptor deleteCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class); + verify(mockSqsBackend).deleteMessage(deleteCaptor.capture()); + assertEquals("receipt1", deleteCaptor.getValue().receiptHandle()); } @Test - public void testWhenMultipartUploadEnabledWithCustomKMSThenKMSIsAppliedToMultipart() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) - .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundDisabled_ThrowsException() { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withIgnorePayloadNotFound(false) + .withStreamUploadEnabled(true); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message message = Message.builder() + .messageId("msg1") + .body(s3Pointer) + .receiptHandle("receipt1") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) 
+ .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(message) + .build(); - assertEquals(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY, extendedClientConfiguration.getServerSideEncryptionStrategy()); - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture> failedFuture = new CompletableFuture<>(); + failedFuture.completeExceptionally(NoSuchKeyException.builder().build()); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(failedFuture); + + assertThrows(CompletionException.class, () -> { + clientWithIgnore.receiveMessageAsStream(request).join(); + }); } @Test - public void testWhenMultipartUploadEnabledAtExactThresholdThenMultipartIsNotUsed() { - String messageBody = generateStringWithLength(MULTIPART_UPLOAD_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + public void testReceiveMessageAsStream_MultipleMessages_MixedTypes() throws IOException { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSizeThreshold(262144) // 256KB + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(1024 * 1024); // 1MB - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + ReceiveMessageRequest request = ReceiveMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); + + // Small message (no S3) + Message smallMessage = Message.builder() + .messageId("msg1") + .body("small message") + .receiptHandle("receipt1") + .build(); + + // Large message (with S3 pointer) + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "test-key").toJson(); + Message largeMessage = Message.builder() + .messageId("msg2") + .body(s3Pointer) + .receiptHandle("receipt2") + .messageAttributes(ImmutableMap.of( + "ExtendedPayloadSize", MessageAttributeValue.builder().stringValue("300000").dataType("Number").build() + )) + .build(); + + ReceiveMessageResponse sqsResponse = ReceiveMessageResponse.builder() + .messages(Arrays.asList(smallMessage, largeMessage)) + .build(); + + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); + + when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + + CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); + when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(futureStream); + + CompletableFuture future = clientWithStream.receiveMessageAsStream(request); + ReceiveStreamMessageResponse response = future.join(); + + assertEquals(2, 
response.streamMessages().size()); + + StreamMessage msg1 = response.streamMessages().get(0); + assertEquals("msg1", msg1.getMessage().messageId()); + assertFalse(msg1.hasStreamPayload()); + + StreamMessage msg2 = response.streamMessages().get(1); + assertEquals("msg2", msg2.getMessage().messageId()); + assertTrue(msg2.hasStreamPayload()); + assertSame(mockStream, msg2.getPayloadStream()); } @Test - public void testWhenMultipartUploadFailsThenFallsBackToStandardUpload() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() { + int fileSizeBytes = 500_000; + String fileContent = generateStringWithLength(fileSizeBytes); + java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(), + "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build() + )) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest).join(); + ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig) + .sendStreamMessage(request, fileStream, fileSizeBytes).join(); + + ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(mockS3, times(1)).putObject(s3Captor.capture(), any(AsyncRequestBody.class)); + assertEquals(S3_BUCKET_NAME, s3Captor.getValue().bucket()); + + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + + assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME)); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName")); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType")); + assertTrue(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + assertEquals(String.valueOf(fileSizeBytes), + sqsCaptor.getValue().messageAttributes() + .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()); + } - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + @Test + public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() { + String smallMessage = "Small notification message"; + java.io.InputStream stream = new java.io.ByteArrayInputStream(smallMessage.getBytes(StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build() + )) + .build(); + + ((AmazonSQSExtendedAsyncClient) extendedSqsWithDefaultConfig) + .sendStreamMessage(request, stream, smallMessage.length()).join(); + + verify(mockS3, 
never()).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + assertEquals(smallMessage, sqsCaptor.getValue().messageBody()); + + assertFalse(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); } } + diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 9459fd0..4fb4a3a 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -15,13 +15,45 @@ package com.amazon.sqs.javamessaging; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; + import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ApiName; import software.amazon.awssdk.core.ResponseInputStream; @@ -52,37 +84,6 @@ import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static 
org.junit.jupiter.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - /** * Tests the AmazonSQSExtendedClient class. */ @@ -115,10 +116,13 @@ public class AmazonSQSExtendedClientTest { // should be > 1 and << SQS_SIZE_LIMIT private static final int ARBITRARY_SMALLER_THRESHOLD = 500; - // Multipart upload thresholds - private static final int MULTIPART_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default - private static final int LESS_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD - 1; - private static final int MORE_THAN_MULTIPART_THRESHOLD = MULTIPART_UPLOAD_THRESHOLD + 1; + // Stream upload thresholds + private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default + private static final int LESS_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD - 1; + private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1; + + // Stream part size + private static final int STREAM_UPLOAD_PART_SIZE = 5 * 1024 * 1024; // 5MB @BeforeEach public void setupClients() { @@ -785,193 +789,216 @@ private String getSampleLargeReceiptHandle(String originalReceiptHandle) { } @Test - public void testWhenMultipartUploadDisabledThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(false); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); - - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); - - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); - } - - @Test - public void testWhenMultipartUploadEnabledAndMessageBelowThresholdThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(LESS_THAN_MULTIPART_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); - - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); - - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); - } - - @Test - public void testWhenMultipartUploadEnabledAndMessageAboveThresholdThenMultipartIsAttempted() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - 
.withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); - - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); - - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); - } - - @Test - public void testWhenMultipartUploadEnabledWithCustomThresholdThenThresholdIsHonored() { - int customThreshold = 1024 * 1024; - int messageLength = customThreshold + 1000; - String messageBody = generateStringWithLength(messageLength); + public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException { + String largeMessageBody = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD); + + ResponseInputStream mockStream = mock(ResponseInputStream.class); + when(mockStream.read(any(byte[].class))).thenReturn(-1); + when(mockS3.getObject(isA(GetObjectRequest.class))).thenReturn(mockStream); + + String s3Key = "stream-key-" + UUID.randomUUID(); + String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, s3Key).toJson(); + Message sqsMessage = Message.builder() + .messageAttributes(ImmutableMap.of(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, + MessageAttributeValue.builder().dataType("Number").stringValue(String.valueOf(largeMessageBody.length())).build())) + .body(pointer) + .receiptHandle("test-receipt-handle") + .build(); + + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder().messages(sqsMessage).build()); ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(customThreshold); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); - - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); - - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); - } - - @Test - public void testWhenMultipartUploadEnabledWithAlwaysThroughS3ThenSmallMessagesAlsoUseS3() { - String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) - .withAlwaysThroughS3(true); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); - - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); - - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withStreamUploadEnabled(true) + .withStreamUploadPartSize(STREAM_UPLOAD_PART_SIZE) + .withS3Region("ap-south-1") + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); + AmazonSQSExtendedClient sqsExtended = new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration); + + ReceiveMessageRequest request = 
ReceiveMessageRequest.builder().queueUrl(SQS_QUEUE_URL).build();
+        ReceiveStreamMessageResponse response = sqsExtended.receiveMessageAsStream(request);
+        sqsExtended.close();
+
+        assertEquals(1, response.streamMessages().size());
+        StreamMessage streamMessage = response.streamMessages().get(0);
+
+        assertTrue(streamMessage.hasStreamPayload());
+
+        ResponseInputStream payloadStream = streamMessage.getPayloadStream();
+        assertNotNull(payloadStream);
+
+        verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class));
+        assertEquals("-..s3BucketName..-" + S3_BUCKET_NAME + "-..s3BucketName..-" + "-..s3Key..-" + s3Key + "-..s3Key..-" + "test-receipt-handle", streamMessage.getMessage().receiptHandle());
     }
 
     @Test
-    public void testWhenMultipartUploadEnabledWithS3KeyPrefixThenPrefixIsUsed() {
-        String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD);
-        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration()
-            .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
-            .withMultipartUploadEnabled(true)
-            .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD)
-            .withS3KeyPrefix(S3_KEY_PREFIX);
-        SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration));
-
-        SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
-        sqsExtended.sendMessage(messageRequest);
-
-        verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class));
-    }
+    public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() {
+        // Use case: Uploading a large file (e.g., 6MB) via stream to avoid loading in memory
+        // Also tests StreamPayloadStore multipart upload path when stream upload is enabled
+        int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD; // 6MB - above stream threshold
+        String fileContent = generateStringWithLength(fileSizeBytes);
+        java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(java.nio.charset.StandardCharsets.UTF_8));
+
+        // Create client with stream upload enabled to test StreamPayloadStore path
+        ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration()
+            .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+            .withStreamUploadEnabled(true)
+            .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD)
+            .withS3Region("us-east-1");
+
+        SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig));
+
+        SendMessageRequest request = SendMessageRequest.builder()
+            .queueUrl(SQS_QUEUE_URL)
+            .messageAttributes(ImmutableMap.of(
+                "fileName", MessageAttributeValue.builder().stringValue("largefile.json").dataType("String").build(),
+                "contentType", MessageAttributeValue.builder().stringValue("application/json").dataType("String").build()
+            ))
+            .build();
 
-    @Test
-    public void testWhenMultipartUploadEnabledThenConfigurationWithCorrectThreshold() {
-        ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration()
-            .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
-            .withMultipartUploadEnabled(true)
-            .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD);
+        ((AmazonSQSExtendedClient) streamClient).sendStreamMessage(request, fileStream, fileSizeBytes);
 
-        assertTrue(extendedClientConfiguration.isMultipartUploadEnabled());
-        assertEquals(MULTIPART_UPLOAD_THRESHOLD, extendedClientConfiguration.getMultipartUploadThreshold());
+        ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+        verify(mockSqsBackend, 
times(1)).sendMessage(sqsCaptor.capture()); + + assertTrue(sqsCaptor.getValue().messageBody().contains(S3_BUCKET_NAME)); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("fileName")); + assertTrue(sqsCaptor.getValue().messageAttributes().containsKey("contentType")); + assertTrue(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + assertEquals(String.valueOf(fileSizeBytes), + sqsCaptor.getValue().messageAttributes() + .get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()); } @Test - public void testWhenMultipartUploadDisabledByDefaultThenStandardUploadIsUsed() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + public void testSendStreamMessage_SmallPayload_SendsDirectlyWithoutS3() { + String smallMessage = "Small notification message"; + java.io.InputStream stream = new java.io.ByteArrayInputStream(smallMessage.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "messageType", MessageAttributeValue.builder().stringValue("notification").dataType("String").build() + )) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); + ((AmazonSQSExtendedClient) extendedSqsWithDefaultConfig).sendStreamMessage(request, stream, smallMessage.length()); - assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + verify(mockS3, never()).putObject(any(PutObjectRequest.class), any(RequestBody.class)); + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend, times(1)).sendMessage(sqsCaptor.capture()); + assertEquals(smallMessage, sqsCaptor.getValue().messageBody()); + assertFalse(sqsCaptor.getValue().messageAttributes() + .containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); } @Test - public void testWhenMultipartUploadEnabledWithMessageBatchThenLargeMessagesUseMultipart() { - int customThreshold = 500_000; // 500KB multipart threshold - - // 3 messages above SQS size limit (256KB) will be stored in S3: - // - 300K uses standard S3 upload - // - 600K and 700K use multipart upload (above 500K threshold) - int[] messageLengthForCounter = new int[] { - 100_000, - 200_000, - 300_000, - 600_000, - 700_000 - }; - - List batchEntries = new ArrayList<>(); - for (int i = 0; i < messageLengthForCounter.length; i++) { - int messageLength = messageLengthForCounter[i]; - String messageBody = generateStringWithLength(messageLength); - SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() - .id("entry_" + i) - .messageBody(messageBody) - .build(); - batchEntries.add(entry); - } - - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(customThreshold); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, 
extendedClientConfiguration)); - - SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); - sqsExtended.sendMessageBatch(batchRequest); + public void testSendStreamMessage_WithS3KeyPrefix() { + int dataSize = MORE_THAN_SQS_SIZE_LIMIT; + String largeData = generateStringWithLength(dataSize); + java.io.InputStream stream = new java.io.ByteArrayInputStream(largeData.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .build(); - verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); - verify(mockSqsBackend, times(1)).sendMessageBatch(isA(SendMessageBatchRequest.class)); + ((AmazonSQSExtendedClient) extendedSqsWithS3KeyPrefix).sendStreamMessage(request, stream, dataSize); + ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class)); + assertTrue(s3Captor.getValue().key().startsWith(S3_KEY_PREFIX), + "S3 key should start with prefix: " + S3_KEY_PREFIX); } @Test - public void testWhenMultipartUploadEnabledWithCustomKMSThenKMSIsAppliedToMultipart() { - String messageBody = generateStringWithLength(MORE_THAN_MULTIPART_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD) - .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { + // Use case: Batch with mixed sizes - small goes direct, large goes to S3 + // Also tests StreamPayloadStore multipart upload path for very large messages + + // Create client with stream upload enabled to test StreamPayloadStore path + ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD) // 5MB + .withS3Region("us-east-1"); + + SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig)); + + List entries = new ArrayList<>(); + List streams = new ArrayList<>(); + List contentLengths = new ArrayList<>(); + + // Small message (100KB - below SQS limit) + String smallMsg = generateStringWithLength(100_000); + entries.add(SendMessageBatchRequestEntry.builder() + .id("msg1") + .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("small").dataType("String").build())) + .build()); + streams.add(new java.io.ByteArrayInputStream(smallMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + contentLengths.add((long) smallMsg.length()); + + // Large message (300KB - above SQS limit but below stream threshold) + String largeMsg = generateStringWithLength(300_000); + entries.add(SendMessageBatchRequestEntry.builder() + .id("msg2") + .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("large").dataType("String").build())) + .build()); + streams.add(new java.io.ByteArrayInputStream(largeMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + contentLengths.add((long) largeMsg.length()); + + // Very large message (6MB - above stream threshold, uses 
multipart) + String veryLargeMsg = generateStringWithLength(MORE_THAN_STREAM_THRESHOLD); + entries.add(SendMessageBatchRequestEntry.builder() + .id("msg3") + .build()); + streams.add(new java.io.ByteArrayInputStream(veryLargeMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + contentLengths.add((long) veryLargeMsg.length()); + + SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .entries(entries) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); + ((AmazonSQSExtendedClient) streamClient).sendStreamMessageBatch(batchRequest, streams, contentLengths); - assertEquals(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY, extendedClientConfiguration.getServerSideEncryptionStrategy()); - verify(mockSqsBackend, times(1)).sendMessage(isA(SendMessageRequest.class)); + // Verify 2 large messages stored in S3 (300KB and 6MB) + ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(mockSqsBackend, times(1)).sendMessageBatch(sqsCaptor.capture()); + SendMessageBatchRequestEntry firstEntry = sqsCaptor.getValue().entries().get(0); + assertEquals(smallMsg, firstEntry.messageBody()); + SendMessageBatchRequestEntry secondEntry = sqsCaptor.getValue().entries().get(1); + assertTrue(secondEntry.messageBody().contains(S3_BUCKET_NAME)); + assertTrue(secondEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + assertTrue(secondEntry.messageAttributes().containsKey("size")); + assertEquals("large", secondEntry.messageAttributes().get("size").stringValue()); + SendMessageBatchRequestEntry thirdEntry = sqsCaptor.getValue().entries().get(2); + assertTrue(thirdEntry.messageBody().contains(S3_BUCKET_NAME)); + assertTrue(thirdEntry.messageAttributes().containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); } @Test - public void testWhenMultipartUploadEnabledAtExactThresholdThenMultipartIsNotUsed() { - String messageBody = generateStringWithLength(MULTIPART_UPLOAD_THRESHOLD); - ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(MULTIPART_UPLOAD_THRESHOLD); - SqsClient sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + public void testSendStreamMessage_WithEncryption_AppliesKMSToS3Upload() { + int dataSize = MORE_THAN_SQS_SIZE_LIMIT; + String sensitiveData = generateStringWithLength(dataSize); + java.io.InputStream stream = new java.io.ByteArrayInputStream(sensitiveData.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + + SendMessageRequest request = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageAttributes(ImmutableMap.of( + "dataType", MessageAttributeValue.builder().stringValue("sensitive").dataType("String").build() + )) + .build(); - SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); - sqsExtended.sendMessage(messageRequest); + ((AmazonSQSExtendedClient) extendedSqsWithCustomKMS).sendStreamMessage(request, stream, dataSize); - verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class)); + // Verify S3 storage with KMS + ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class); + verify(mockS3, 
times(1)).putObject(s3Captor.capture(), any(RequestBody.class)); + assertEquals(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID, s3Captor.getValue().ssekmsKeyId()); + + // Verify message sent to SQS with pointer + verify(mockSqsBackend, times(1)).sendMessage(any(SendMessageRequest.class)); } } diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java index 5004b65..ae8fa81 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java @@ -87,45 +87,45 @@ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() { } @Test - public void testMultipartUploadDisabledByDefault() { + public void testStreamUploadDisabledByDefault() { ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - assertFalse(extendedClientConfiguration.isMultipartUploadEnabled()); - assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadThreshold()); // 5MB default + assertFalse(extendedClientConfiguration.isStreamUploadEnabled()); + assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadThreshold()); // 5MB default } @Test - public void testMultipartUploadEnabled() { + public void testStreamUploadEnabled() { ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withMultipartUploadEnabled(true); + extendedClientConfiguration.withStreamUploadEnabled(true); - assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + assertTrue(extendedClientConfiguration.isStreamUploadEnabled()); } @Test - public void testMultipartUploadThresholdCustomValue() { + public void testStreamUploadThresholdCustomValue() { int customThreshold = 10 * 1024 * 1024; // 10MB ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withMultipartUploadThreshold(customThreshold); + extendedClientConfiguration.withStreamUploadThreshold(customThreshold); - assertEquals(customThreshold, extendedClientConfiguration.getMultipartUploadThreshold()); + assertEquals(customThreshold, extendedClientConfiguration.getStreamUploadThreshold()); } @Test - public void testMultipartUploadPartSizeCustomValue() { + public void testStreamUploadPartSizeCustomValue() { int customPartSize = 10 * 1024 * 1024; // 10MB ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withMultipartUploadPartSize(customPartSize); + extendedClientConfiguration.withStreamUploadPartSize(customPartSize); - assertEquals(customPartSize, extendedClientConfiguration.getMultipartUploadPartSize()); + assertEquals(customPartSize, extendedClientConfiguration.getStreamUploadPartSize()); } @Test - public void testMultipartUploadPartSizeBelowMinimumRoundedUpTo5MB() { + public void testStreamUploadPartSizeBelowMinimumRoundedUpTo5MB() { int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum) ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withMultipartUploadPartSize(belowMinimum); + extendedClientConfiguration.withStreamUploadPartSize(belowMinimum); - assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadPartSize()); + 
assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadPartSize()); } } diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java index fafa791..e7b3757 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java @@ -226,56 +226,56 @@ public void testS3keyPrefixWithALargeString() { } @Test - public void testMultipartUploadEnabled() { + public void testStreamUploadEnabled() { ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); - extendedClientConfiguration.withMultipartUploadEnabled(true); + extendedClientConfiguration.withStreamUploadEnabled(true); - assertTrue(extendedClientConfiguration.isMultipartUploadEnabled()); + assertTrue(extendedClientConfiguration.isStreamUploadEnabled()); } @Test - public void testMultipartUploadThresholdCustomValue() { + public void testStreamUploadThresholdCustomValue() { int customThreshold = 10 * 1024 * 1024; // 10MB ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); - extendedClientConfiguration.withMultipartUploadThreshold(customThreshold); + extendedClientConfiguration.withStreamUploadThreshold(customThreshold); - assertEquals(customThreshold, extendedClientConfiguration.getMultipartUploadThreshold()); + assertEquals(customThreshold, extendedClientConfiguration.getStreamUploadThreshold()); } @Test - public void testMultipartUploadPartSizeCustomValue() { + public void testStreamUploadPartSizeCustomValue() { int customPartSize = 10 * 1024 * 1024; // 10MB ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); - extendedClientConfiguration.withMultipartUploadPartSize(customPartSize); + extendedClientConfiguration.withStreamUploadPartSize(customPartSize); - assertEquals(customPartSize, extendedClientConfiguration.getMultipartUploadPartSize()); + assertEquals(customPartSize, extendedClientConfiguration.getStreamUploadPartSize()); } @Test - public void testMultipartUploadPartSizeBelowMinimumRoundedUpTo5MB() { + public void testStreamUploadPartSizeBelowMinimumRoundedUpTo5MB() { int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum) ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); - extendedClientConfiguration.withMultipartUploadPartSize(belowMinimum); + extendedClientConfiguration.withStreamUploadPartSize(belowMinimum); - assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getMultipartUploadPartSize()); + assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadPartSize()); } @Test - public void testMultipartConfigurationInCopyConstructor() { + public void testStreamConfigurationInCopyConstructor() { S3Client s3 = mock(S3Client.class); int customThreshold = 10 * 1024 * 1024; int customPartSize = 8 * 1024 * 1024; ExtendedClientConfiguration originalConfig = new ExtendedClientConfiguration(); originalConfig.withPayloadSupportEnabled(s3, s3BucketName) - .withMultipartUploadEnabled(true) - .withMultipartUploadThreshold(customThreshold) - .withMultipartUploadPartSize(customPartSize); + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(customThreshold) + .withStreamUploadPartSize(customPartSize); ExtendedClientConfiguration copiedConfig = new ExtendedClientConfiguration(originalConfig); - 
assertTrue(copiedConfig.isMultipartUploadEnabled()); - assertEquals(customThreshold, copiedConfig.getMultipartUploadThreshold()); - assertEquals(customPartSize, copiedConfig.getMultipartUploadPartSize()); + assertTrue(copiedConfig.isStreamUploadEnabled()); + assertEquals(customThreshold, copiedConfig.getStreamUploadThreshold()); + assertEquals(customPartSize, copiedConfig.getStreamUploadPartSize()); } } From 439345e5bfb4f86c9baf9883e5e752156f6e5eef Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 09:21:00 +0700 Subject: [PATCH 3/8] add stream support --- .../StreamingIntegrationTest.java | 717 ++++++++++++++++++ 1 file changed, 717 insertions(+) create mode 100644 src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java diff --git a/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java b/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java new file mode 100644 index 0000000..2905b00 --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java @@ -0,0 +1,717 @@ +package com.amazon.sqs.javamessaging; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.CreateQueueRequest; +import software.amazon.awssdk.services.sqs.model.DeleteQueueRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Integration test for streaming functionality using LocalStack. + * This test requires LocalStack to be running. + */ +@EnabledIfEnvironmentVariable(named = "LOCALSTACK_ENDPOINT", matches = ".*") +public class StreamingIntegrationTest { + + private static final String LOCALSTACK_ENDPOINT = System.getenv("LOCALSTACK_ENDPOINT") != null ? 
+ System.getenv("LOCALSTACK_ENDPOINT") : "http://localhost:4566"; + private static final String BUCKET_NAME = "offload-bucket"; + private static final String QUEUE_NAME = "streaming-test-queue"; + + private static SqsClient sqsClient; + private static S3Client s3Client; + private static AmazonSQSExtendedClient extendedClient; + private static String queueUrl; + + @BeforeAll + public static void setup() { + // Create clients pointing to LocalStack + sqsClient = SqsClient.builder() + .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("dummy", "dummy"))) + .region(Region.of("ap-southeast-1")) + .build(); + + s3Client = S3Client.builder() + .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("dummy", "dummy"))) + .region(Region.of("ap-southeast-1")) + .serviceConfiguration(S3Configuration.builder() + .pathStyleAccessEnabled(true) + .build()) + .build(); + + // Create S3AsyncClient with same configuration for TransferManager + // software.amazon.awssdk.services.s3.S3AsyncClient s3AsyncClient = + // software.amazon.awssdk.services.s3.S3AsyncClient.builder() + // .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) + // .credentialsProvider(StaticCredentialsProvider.create( + // AwsBasicCredentials.create("dummy", "dummy"))) + // .region(Region.of("ap-southeast-1")) + // .serviceConfiguration(S3Configuration.builder() + // .pathStyleAccessEnabled(true) + // .build()) + // .multipartEnabled(true) + // .multipartConfiguration(conf -> conf + // .minimumPartSizeInBytes(10 * 1024 * 1024L) // 10MB + // .thresholdInBytes(16 * 1024 * 1024L)) // 16MB + // .build(); + + // Create extended client with stream support + ExtendedClientConfiguration config = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(s3Client, BUCKET_NAME) + .withStreamUploadEnabled(true) + .withStreamUploadThreshold(5 * 1024 * 1024) // 5MB + .withStreamUploadPartSize(10 * 1024 * 1024) // 10MB + .withPayloadSizeThreshold(256 * 1024) // 256KB + .withS3Region("ap-southeast-1"); + + System.out.println("S3 Region configured: " + config.getS3Region()); + + extendedClient = new AmazonSQSExtendedClient(sqsClient, config); + + // Create queue + CreateQueueRequest createQueueRequest = CreateQueueRequest.builder() + .queueName(QUEUE_NAME) + .build(); + sqsClient.createQueue(createQueueRequest); + + // Get queue URL + GetQueueUrlRequest getQueueUrlRequest = GetQueueUrlRequest.builder() + .queueName(QUEUE_NAME) + .build(); + queueUrl = sqsClient.getQueueUrl(getQueueUrlRequest).queueUrl(); + + System.out.println("Integration test setup complete. 
Using LocalStack at: " + LOCALSTACK_ENDPOINT); + } + + @AfterAll + public static void cleanup() { + if (queueUrl != null) { + try { + DeleteQueueRequest deleteQueueRequest = DeleteQueueRequest.builder() + .queueUrl(queueUrl) + .build(); + sqsClient.deleteQueue(deleteQueueRequest); + } catch (Exception e) { + System.err.println("Failed to delete queue: " + e.getMessage()); + } + } + + if (sqsClient != null) { + sqsClient.close(); + } + if (s3Client != null) { + s3Client.close(); + } + if (extendedClient != null) { + extendedClient.close(); + } + } + + // @Test + // public void testSendAndReceiveLargeMessageWithStreaming() throws IOException { + // // Create a large message (6MB) that will trigger stream upload + // String largeMessage = generateLargeMessage(6 * 1024 * 1024); // 6MB + + // // Send the large message + // SendMessageRequest sendRequest = SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .messageBody(largeMessage) + // .build(); + + // extendedClient.sendMessage(sendRequest); + // System.out.println("Sent large message (" + largeMessage.length() + " chars, ~" + + // (largeMessage.length() * 2 / 1024 / 1024) + "MB)"); + + // // Receive the message using streaming + // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() + // .queueUrl(queueUrl) + // .maxNumberOfMessages(1) + // .build(); + + // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); + + // assertEquals(1, streamResponse.streamMessages().size(), + // "Should receive exactly one message"); + + // StreamMessage streamMessage = streamResponse.streamMessages().get(0); + + // assertTrue(streamMessage.hasStreamPayload(), + // "Large message should have streaming payload"); + + // ResponseInputStream payloadStream = streamMessage.getPayloadStream(); + // assertNotNull(payloadStream, "Payload stream should not be null"); + + // String receivedContent = readStreamContent(payloadStream); + + // assertEquals(largeMessage, receivedContent, + // "Received content should match sent message"); + + // System.out.println("Successfully received and streamed large message content"); + // } + + // @Test + // public void testSendAndReceiveSmallMessageWithoutStreaming() { + // String smallMessage = "This is a small message that stays in SQS"; + + // SendMessageRequest sendRequest = SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .messageBody(smallMessage) + // .build(); + + // extendedClient.sendMessage(sendRequest); + // System.out.println("Sent small message (" + smallMessage.length() + " chars)"); + + // // Receive the message using streaming + // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() + // .queueUrl(queueUrl) + // .maxNumberOfMessages(1) + // .build(); + + // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); + + // assertEquals(1, streamResponse.streamMessages().size(), + // "Should receive exactly one message"); + + // StreamMessage streamMessage = streamResponse.streamMessages().get(0); + + // assertFalse(streamMessage.hasStreamPayload(), + // "Small message should not have streaming payload"); + + // assertEquals(smallMessage, streamMessage.getMessage().body(), + // "Small message content should be in message body"); + + // System.out.println("Successfully received small message without streaming"); + // } + + // @Test + // public void testMixedMessageTypes() throws IOException { + // String smallMessage = "Small message"; + // String largeMessage = 
generateLargeMessage(4 * 1024 * 1024); // 4MB + + // extendedClient.sendMessage(SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .messageBody(smallMessage) + // .build()); + + // extendedClient.sendMessage(SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .messageBody(largeMessage) + // .build()); + + // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() + // .queueUrl(queueUrl) + // .maxNumberOfMessages(10) + // .build(); + + // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); + + // assertEquals(2, streamResponse.streamMessages().size(), + // "Should receive exactly two messages"); + + // StreamMessage smallStreamMessage = null; + // StreamMessage largeStreamMessage = null; + + // for (StreamMessage msg : streamResponse.streamMessages()) { + // if (msg.hasStreamPayload()) { + // largeStreamMessage = msg; + // } else { + // smallStreamMessage = msg; + // } + // } + + // assertNotNull(smallStreamMessage, "Should have small message"); + // assertNotNull(largeStreamMessage, "Should have large message"); + + // assertEquals(smallMessage, smallStreamMessage.getMessage().body()); + + // ResponseInputStream payloadStream = largeStreamMessage.getPayloadStream(); + // String receivedLargeContent = readStreamContent(payloadStream); + // assertEquals(largeMessage, receivedLargeContent); + + // System.out.println("Successfully handled mixed small and large messages"); + // } + + // @Test + // public void testStreamingUploadWithMultipartConfiguration() throws IOException { + // // Create a 20MB message that will definitely trigger multipart upload + // // With 10MB part size and 16MB threshold configured + // String veryLargeMessage = generateLargeMessage(20 * 1024 * 1024); // 20MB + // byte[] messageBytes = veryLargeMessage.getBytes(StandardCharsets.UTF_8); + + // System.out.println("Sending very large message (" + + // (messageBytes.length / 1024 / 1024) + "MB) - should use multipart upload"); + + // // Send the message using sendStreamMessage - this should trigger streaming upload with multipart + // SendMessageRequest sendRequest = SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .build(); + + // long startTime = System.currentTimeMillis(); + // java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); + // extendedClient.sendStreamMessage(sendRequest, messageStream, messageBytes.length); + // long uploadTime = System.currentTimeMillis() - startTime; + + // System.out.println("Upload completed in " + uploadTime + "ms using TransferManager with multipart"); + + // // Receive using streaming to avoid loading entire message into memory + // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() + // .queueUrl(queueUrl) + // .maxNumberOfMessages(1) + // .waitTimeSeconds(10) + // .build(); + + // startTime = System.currentTimeMillis(); + // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); + // long receiveTime = System.currentTimeMillis() - startTime; + + // System.out.println("Receive completed in " + receiveTime + "ms"); + + // assertEquals(1, streamResponse.streamMessages().size(), + // "Should receive exactly one message"); + + // StreamMessage streamMessage = streamResponse.streamMessages().get(0); + + // assertTrue(streamMessage.hasStreamPayload(), + // "Very large message should have streaming payload"); + + // ResponseInputStream payloadStream = streamMessage.getPayloadStream(); + // 
assertNotNull(payloadStream, "Payload stream should not be null"); + + // // Read the content in chunks to demonstrate true streaming + // startTime = System.currentTimeMillis(); + // String receivedContent = readStreamContent(payloadStream); + // long readTime = System.currentTimeMillis() - startTime; + + // System.out.println("Stream read completed in " + readTime + "ms"); + + // assertEquals(veryLargeMessage.length(), receivedContent.length(), + // "Received content length should match sent message length"); + + // // Verify start and end markers to ensure content integrity + // assertTrue(receivedContent.startsWith("START:"), + // "Content should start with START marker"); + // assertTrue(receivedContent.endsWith(":END"), + // "Content should end with END marker"); + + // System.out.println("Successfully sent and received 20MB message using multipart streaming"); + // } + + // @Test + // public void testStreamingWithCustomPartSizeAndThreshold() throws IOException { + // // This test verifies that the configured part size (10MB) and threshold (16MB) + // // are being used correctly for multipart uploads + + // // Create a message just above the threshold (17MB) + // String largeMessage = generateLargeMessage(17 * 1024 * 1024); // 17MB + // byte[] messageBytes = largeMessage.getBytes(StandardCharsets.UTF_8); + + // System.out.println("Testing with " + (messageBytes.length / 1024 / 1024) + + // "MB message - above 16MB threshold, should trigger multipart"); + + // SendMessageRequest sendRequest = SendMessageRequest.builder() + // .queueUrl(queueUrl) + // .build(); + + // java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); + // extendedClient.sendStreamMessage(sendRequest, messageStream, messageBytes.length); + // System.out.println("Successfully sent " + (messageBytes.length / 1024 / 1024) + + // "MB message with multipart upload"); + + // // Receive and verify + // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() + // .queueUrl(queueUrl) + // .maxNumberOfMessages(1) + // .build(); + + // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); + + // assertEquals(1, streamResponse.streamMessages().size()); + // StreamMessage streamMessage = streamResponse.streamMessages().get(0); + + // assertTrue(streamMessage.hasStreamPayload(), + // "Message above threshold should have streaming payload"); + + // String receivedContent = readStreamContent(streamMessage.getPayloadStream()); + + // assertEquals(largeMessage.length(), receivedContent.length(), + // "Received content should match sent message length"); + + // System.out.println("Successfully verified custom part size and threshold configuration"); + // } + + @Test + public void testMemoryUsageComparisonSendTraditionalVsStreaming() throws IOException, InterruptedException { + // Test 1: Compare memory usage for SENDING large messages + // In real-world: sender and receiver are on different machines with separate memory + + final int messageSizeBytes = 50 * 1024 * 1024; // 50MB + + System.out.println("\n=== SEND Memory Usage Comparison ==="); + System.out.println("Message size: " + (messageSizeBytes / 1024 / 1024) + "MB"); + System.out.println("Comparing traditional sendMessage() vs sendStreamMessage()"); + + Runtime runtime = Runtime.getRuntime(); + + // TRADITIONAL SEND TEST FIRST + System.out.println("\n--- Traditional sendMessage() ---"); + + System.gc(); + Thread.sleep(200); + long traditionalStartMemory = runtime.totalMemory() - 
runtime.freeMemory(); + long traditionalPeakMemory = traditionalStartMemory; + System.out.println("Baseline memory: " + (traditionalStartMemory / 1024 / 1024) + "MB"); + + String largeMessage = generateLargeMessage(messageSizeBytes); + long afterLoadMemory = runtime.totalMemory() - runtime.freeMemory(); + traditionalPeakMemory = Math.max(traditionalPeakMemory, afterLoadMemory); + System.out.println("After loading file into String: " + (afterLoadMemory / 1024 / 1024) + "MB (+" + + ((afterLoadMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + + SendMessageRequest sendRequest1 = SendMessageRequest.builder() + .queueUrl(queueUrl) + .messageBody(largeMessage) + .build(); + + long beforeSend = System.currentTimeMillis(); + extendedClient.sendMessage(sendRequest1); + long sendTime = System.currentTimeMillis() - beforeSend; + + long afterSendMemory = runtime.totalMemory() - runtime.freeMemory(); + traditionalPeakMemory = Math.max(traditionalPeakMemory, afterSendMemory); + + System.out.println("Send time: " + sendTime + "ms"); + System.out.println("Memory after send: " + (afterSendMemory / 1024 / 1024) + "MB (+" + + ((afterSendMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Peak memory: " + (traditionalPeakMemory / 1024 / 1024) + "MB (+" + + ((traditionalPeakMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + + assertTrue(largeMessage.startsWith("START:"), "Traditional send message should start with START:"); + assertTrue(largeMessage.endsWith(":END"), "Traditional send message should end with :END"); + + long traditionalSendMemoryUsed = traditionalPeakMemory - traditionalStartMemory; + + largeMessage = null; + System.gc(); + Thread.sleep(200); + + // STREAMING SEND TEST SECOND (fresh memory state) + System.out.println("\n--- Streaming sendStreamMessage() ---"); + + System.gc(); + Thread.sleep(200); + long streamingStartMemory = runtime.totalMemory() - runtime.freeMemory(); + long streamingPeakMemory = streamingStartMemory; + System.out.println("Baseline memory: " + (streamingStartMemory / 1024 / 1024) + "MB"); + + byte[] messageBytes = generateLargeMessage(messageSizeBytes).getBytes(StandardCharsets.UTF_8); + long afterGenerateMemory = runtime.totalMemory() - runtime.freeMemory(); + streamingPeakMemory = Math.max(streamingPeakMemory, afterGenerateMemory); + System.out.println("After generating bytes for stream: " + (afterGenerateMemory / 1024 / 1024) + "MB (+" + + ((afterGenerateMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + + SendMessageRequest sendRequest2 = SendMessageRequest.builder() + .queueUrl(queueUrl) + .build(); + + long beforeStreamSend = System.currentTimeMillis(); + java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); + extendedClient.sendStreamMessage(sendRequest2, messageStream, messageBytes.length); + long streamSendTime = System.currentTimeMillis() - beforeStreamSend; + + long afterStreamSendMemory = runtime.totalMemory() - runtime.freeMemory(); + streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamSendMemory); + + System.out.println("Send time: " + streamSendTime + "ms"); + System.out.println("Memory after send: " + (afterStreamSendMemory / 1024 / 1024) + "MB (+" + + ((afterStreamSendMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Peak memory: " + (streamingPeakMemory / 1024 / 1024) + "MB (+" + + ((streamingPeakMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + + String messageBytesStr = new String(messageBytes, StandardCharsets.UTF_8); + 
assertTrue(messageBytesStr.startsWith("START:"), "Streaming send message should start with START:"); + assertTrue(messageBytesStr.endsWith(":END"), "Streaming send message should end with :END"); + messageBytesStr = null; + + long streamingSendMemoryUsed = streamingPeakMemory - streamingStartMemory; + + messageBytes = null; + System.gc(); + Thread.sleep(200); + + System.out.println("\n=== SEND Comparison ==="); + System.out.println("Traditional send peak: " + (traditionalSendMemoryUsed / 1024 / 1024) + "MB"); + System.out.println("Streaming send peak: " + (streamingSendMemoryUsed / 1024 / 1024) + "MB"); + long sendMemorySaved = traditionalSendMemoryUsed - streamingSendMemoryUsed; + double sendPercentSaved = traditionalSendMemoryUsed > 0 ? + (sendMemorySaved * 100.0) / traditionalSendMemoryUsed : 0; + System.out.println("Memory saved: " + (sendMemorySaved / 1024 / 1024) + "MB (" + + String.format("%.1f", sendPercentSaved) + "% reduction)"); + } + + @Test + public void testRealWorldStreamingVsTraditionalReceive() throws IOException, InterruptedException { + // REAL-WORLD SCENARIO: Compare memory usage when actually PROCESSING the content + // This simulates what happens in real applications where you need to consume the data + + final int messageSizeBytes = 50 * 1024 * 1024; // 50MB + + System.out.println("\n=== REAL-WORLD RECEIVE Comparison ==="); + System.out.println("Message size: " + (messageSizeBytes / 1024 / 1024) + "MB"); + System.out.println("Simulating real app: processing content (counting chars, validating data)"); + + Runtime runtime = Runtime.getRuntime(); + + // TRADITIONAL APPROACH: Load entire content, then process it + String largeMessage = generateLargeMessage(messageSizeBytes); + SendMessageRequest sendRequest1 = SendMessageRequest.builder() + .queueUrl(queueUrl) + .messageBody(largeMessage) + .build(); + extendedClient.sendMessage(sendRequest1); + largeMessage = null; // Release reference + System.gc(); + Thread.sleep(200); + + System.out.println("\n--- Traditional: Load entire content, then process ---"); + + System.gc(); + Thread.sleep(200); + long traditionalStartMemory = runtime.totalMemory() - runtime.freeMemory(); + long traditionalPeakMemory = traditionalStartMemory; + System.out.println("Baseline memory: " + (traditionalStartMemory / 1024 / 1024) + "MB"); + + ReceiveMessageRequest receiveRequest1 = ReceiveMessageRequest.builder() + .queueUrl(queueUrl) + .maxNumberOfMessages(1) + .build(); + + long beforeReceive = System.currentTimeMillis(); + software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse response = + extendedClient.receiveMessage(receiveRequest1); + long receiveTime = System.currentTimeMillis() - beforeReceive; + + long afterReceiveMemory = runtime.totalMemory() - runtime.freeMemory(); + traditionalPeakMemory = Math.max(traditionalPeakMemory, afterReceiveMemory); + + // REAL PROCESSING: Access the body (loads entire content) + String receivedBody = response.messages().get(0).body(); + long afterBodyAccessMemory = runtime.totalMemory() - runtime.freeMemory(); + traditionalPeakMemory = Math.max(traditionalPeakMemory, afterBodyAccessMemory); + + // Simulate real processing: count characters, validate content + long beforeProcessing = System.currentTimeMillis(); + int charCount = receivedBody.length(); + boolean hasValidContent = receivedBody.contains("START:") && receivedBody.contains(":END"); + int dataLines = 0; + for (char c : receivedBody.toCharArray()) { + if (c == '\n') dataLines++; + } + long processingTime = System.currentTimeMillis() - 
beforeProcessing; + + long afterProcessingMemory = runtime.totalMemory() - runtime.freeMemory(); + traditionalPeakMemory = Math.max(traditionalPeakMemory, afterProcessingMemory); + + System.out.println("Receive time: " + receiveTime + "ms"); + System.out.println("Processing time: " + processingTime + "ms"); + System.out.println("Content length: " + charCount + " chars"); + System.out.println("Data lines: " + dataLines); + System.out.println("Valid content: " + hasValidContent); + System.out.println("Memory after receive: " + (afterReceiveMemory / 1024 / 1024) + "MB (+" + + ((afterReceiveMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Memory after body access: " + (afterBodyAccessMemory / 1024 / 1024) + "MB (+" + + ((afterBodyAccessMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Memory after processing: " + (afterProcessingMemory / 1024 / 1024) + "MB (+" + + ((afterProcessingMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Peak memory: " + (traditionalPeakMemory / 1024 / 1024) + "MB (+" + + ((traditionalPeakMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); + + long traditionalMemoryUsed = traditionalPeakMemory - traditionalStartMemory; + + receivedBody = null; + response = null; + System.gc(); + Thread.sleep(200); + + // STREAMING APPROACH: Process content as it streams (realistic scenario) + byte[] messageBytes = generateLargeMessage(messageSizeBytes).getBytes(StandardCharsets.UTF_8); + SendMessageRequest sendRequest2 = SendMessageRequest.builder() + .queueUrl(queueUrl) + .build(); + java.io.InputStream sendStream = new java.io.ByteArrayInputStream(messageBytes); + extendedClient.sendStreamMessage(sendRequest2, sendStream, messageBytes.length); + messageBytes = null; + System.gc(); + Thread.sleep(200); + + System.out.println("\n--- Streaming: Process content as it streams ---"); + + System.gc(); + Thread.sleep(200); + long streamingStartMemory = runtime.totalMemory() - runtime.freeMemory(); + long streamingPeakMemory = streamingStartMemory; + System.out.println("Baseline memory: " + (streamingStartMemory / 1024 / 1024) + "MB"); + + ReceiveMessageRequest receiveRequest2 = ReceiveMessageRequest.builder() + .queueUrl(queueUrl) + .maxNumberOfMessages(1) + .build(); + + long beforeStreamReceive = System.currentTimeMillis(); + ReceiveStreamMessageResponse streamResponse = + extendedClient.receiveMessageAsStream(receiveRequest2); + long streamReceiveTime = System.currentTimeMillis() - beforeStreamReceive; + + long afterStreamReceiveMemory = runtime.totalMemory() - runtime.freeMemory(); + streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamReceiveMemory); + + System.out.println("Receive time: " + streamReceiveTime + "ms"); + System.out.println("Memory after receive: " + (afterStreamReceiveMemory / 1024 / 1024) + "MB (+" + + ((afterStreamReceiveMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + + // REAL STREAMING PROCESSING: Process content as it streams + StreamMessage streamMessage = streamResponse.streamMessages().get(0); + ResponseInputStream payloadStream = streamMessage.getPayloadStream(); + + long beforeStreamProcessing = System.currentTimeMillis(); + long totalBytesRead = 0; + int streamCharCount = 0; + boolean streamHasValidContent = false; + int streamDataLines = 0; + boolean foundStart = false; + boolean foundEnd = false; + + try (ResponseInputStream s = payloadStream) { + byte[] buffer = new byte[8192]; // 8KB buffer + int bytesRead; + + // Patterns to search for 
anywhere in the stream + String startPattern = "START:"; + String endPattern = ":END"; + + // Sliding window to handle patterns that span chunk boundaries + byte[] previousChunkTail = new byte[0]; + int maxPatternLength = Math.max(startPattern.length(), endPattern.length()); + + while ((bytesRead = s.read(buffer)) != -1) { + totalBytesRead += bytesRead; + streamCharCount += bytesRead; + + // Count newlines by scanning bytes directly (memory efficient) + for (int i = 0; i < bytesRead; i++) { + if (buffer[i] == '\n') streamDataLines++; + } + + // Pattern matching: search in current chunk + overlap from previous chunk + if (!foundStart || !foundEnd) { + // Combine previous chunk tail with current chunk for pattern search + // This handles patterns that span chunk boundaries + byte[] searchBuffer; + + if (previousChunkTail.length > 0) { + searchBuffer = new byte[previousChunkTail.length + bytesRead]; + System.arraycopy(previousChunkTail, 0, searchBuffer, 0, previousChunkTail.length); + System.arraycopy(buffer, 0, searchBuffer, previousChunkTail.length, bytesRead); + } else { + searchBuffer = buffer; + } + + int searchLength = (searchBuffer == buffer) ? bytesRead : searchBuffer.length; + String chunk = new String(searchBuffer, 0, searchLength, StandardCharsets.UTF_8); + + if (!foundStart && chunk.contains(startPattern)) { + foundStart = true; + } + if (!foundEnd && chunk.contains(endPattern)) { + foundEnd = true; + } + + chunk = null; // Release immediately + + // Save tail of current chunk for next iteration (to handle boundary patterns) + int tailLength = Math.min(bytesRead, maxPatternLength); + previousChunkTail = new byte[tailLength]; + System.arraycopy(buffer, bytesRead - tailLength, previousChunkTail, 0, tailLength); + } + + long currentMemory = runtime.totalMemory() - runtime.freeMemory(); + streamingPeakMemory = Math.max(streamingPeakMemory, currentMemory); + } + } + + streamHasValidContent = foundStart && foundEnd; + long streamProcessingTime = System.currentTimeMillis() - beforeStreamProcessing; + + long afterStreamProcessingMemory = runtime.totalMemory() - runtime.freeMemory(); + streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamProcessingMemory); + + System.out.println("Processing time: " + streamProcessingTime + "ms"); + System.out.println("Total bytes: " + (totalBytesRead / 1024 / 1024) + "MB"); + System.out.println("Content length: " + streamCharCount + " chars"); + System.out.println("Data lines: " + streamDataLines); + System.out.println("Valid content: " + streamHasValidContent); + System.out.println("Memory after processing: " + (afterStreamProcessingMemory / 1024 / 1024) + "MB (+" + + ((afterStreamProcessingMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + System.out.println("Peak memory: " + (streamingPeakMemory / 1024 / 1024) + "MB (+" + + ((streamingPeakMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); + + long streamingMemoryUsed = streamingPeakMemory - streamingStartMemory; + + System.out.println("\n=== REAL-WORLD Comparison ==="); + System.out.println("Traditional peak: " + (traditionalMemoryUsed / 1024 / 1024) + "MB"); + System.out.println("Streaming peak: " + (streamingMemoryUsed / 1024 / 1024) + "MB"); + long memorySaved = traditionalMemoryUsed - streamingMemoryUsed; + double percentSaved = traditionalMemoryUsed > 0 ? 
+ (memorySaved * 100.0) / traditionalMemoryUsed : 0; + System.out.println("Memory saved: " + (memorySaved / 1024 / 1024) + "MB (" + + String.format("%.1f", percentSaved) + "% reduction)"); + + assertTrue(totalBytesRead >= messageSizeBytes - 100 && totalBytesRead <= messageSizeBytes + 100); + assertEquals(charCount, streamCharCount, "Both approaches should count same characters"); + assertEquals(dataLines, streamDataLines, "Both approaches should count same lines"); + assertEquals(hasValidContent, streamHasValidContent, "Both approaches should validate content same way"); + } + + + private String generateLargeMessage(int sizeInBytes) { + int numChars = sizeInBytes; + StringBuilder sb = new StringBuilder(numChars); + sb.append("START:"); + for (int i = 0; i < numChars - 12; i++) { + sb.append((char) ('A' + (i % 26))); + } + sb.append(":END"); + return sb.toString(); + } + + private String readStreamContent(ResponseInputStream stream) throws IOException { + if (stream == null) { + return ""; + } + try (ResponseInputStream s = stream) { + byte[] bytes = s.readAllBytes(); + return new String(bytes, StandardCharsets.UTF_8); + } + } +} \ No newline at end of file From 44a4537ccdb74ccce0f8e65d49350078f0431236 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 09:29:46 +0700 Subject: [PATCH 4/8] cleanup --- .../AmazonSQSExtendedClientTest.java | 80 +++++++++---------- 1 file changed, 36 insertions(+), 44 deletions(-) diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 4fb4a3a..e51d6cb 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -15,45 +15,13 @@ package com.amazon.sqs.javamessaging; -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; import static com.amazon.sqs.javamessaging.StringTestUtil.generateStringWithLength; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.mockStatic; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; -import static org.mockito.Mockito.when; - -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.UUID; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import 
org.mockito.MockedStatic; - import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ApiName; import software.amazon.awssdk.core.ResponseInputStream; @@ -84,6 +52,41 @@ import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.InputStream; + + /** * Tests the AmazonSQSExtendedClient class. 
*/ @@ -834,13 +837,10 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa @Test public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() { - // Use case: Uploading a large file (e.g., 6MB) via stream to avoid loading in memory - // Also tests StreamPayloadStore multipart upload path when stream upload is enabled - int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD; // 6MB - above stream threshold + int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD; String fileContent = generateStringWithLength(fileSizeBytes); java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(java.nio.charset.StandardCharsets.UTF_8)); - // Create client with stream upload enabled to test StreamPayloadStore path ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withStreamUploadEnabled(true) @@ -913,10 +913,6 @@ public void testSendStreamMessage_WithS3KeyPrefix() { @Test public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { - // Use case: Batch with mixed sizes - small goes direct, large goes to S3 - // Also tests StreamPayloadStore multipart upload path for very large messages - - // Create client with stream upload enabled to test StreamPayloadStore path ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withStreamUploadEnabled(true) @@ -962,7 +958,6 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { ((AmazonSQSExtendedClient) streamClient).sendStreamMessageBatch(batchRequest, streams, contentLengths); - // Verify 2 large messages stored in S3 (300KB and 6MB) ArgumentCaptor sqsCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); verify(mockSqsBackend, times(1)).sendMessageBatch(sqsCaptor.capture()); SendMessageBatchRequestEntry firstEntry = sqsCaptor.getValue().entries().get(0); @@ -992,12 +987,9 @@ public void testSendStreamMessage_WithEncryption_AppliesKMSToS3Upload() { ((AmazonSQSExtendedClient) extendedSqsWithCustomKMS).sendStreamMessage(request, stream, dataSize); - // Verify S3 storage with KMS ArgumentCaptor s3Captor = ArgumentCaptor.forClass(PutObjectRequest.class); verify(mockS3, times(1)).putObject(s3Captor.capture(), any(RequestBody.class)); assertEquals(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID, s3Captor.getValue().ssekmsKeyId()); - - // Verify message sent to SQS with pointer verify(mockSqsBackend, times(1)).sendMessage(any(SendMessageRequest.class)); } From 3cf72a582602869ff23b1b7b3c12915a653bcbed Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 10:06:59 +0700 Subject: [PATCH 5/8] cleanup --- .../AmazonSQSExtendedClient.java | 6 +- .../StreamingIntegrationTest.java | 717 ------------------ 2 files changed, 1 insertion(+), 722 deletions(-) delete mode 100644 src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 7fecea3..aab4778 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -165,11 +165,7 @@ public AmazonSQSExtendedClient(SqsClient sqsClient, ExtendedClientConfiguration .minimumPartSizeInBytes(clientConfiguration.getStreamUploadPartSize()) .thresholdInBytes(clientConfiguration.getStreamUploadThreshold()) 
) - // .credentialsProvider(DefaultCredentialsProvider.create()) - .endpointOverride(URI.create("http://localhost:4566")) - .forcePathStyle(true) - .credentialsProvider(StaticCredentialsProvider.create( - AwsBasicCredentials.create("dummy", "dummy"))) + .credentialsProvider(DefaultCredentialsProvider.create()) .region(Region.of(clientConfiguration.getS3Region())) .build(); diff --git a/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java b/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java deleted file mode 100644 index 2905b00..0000000 --- a/src/test/java/com/amazon/sqs/javamessaging/StreamingIntegrationTest.java +++ /dev/null @@ -1,717 +0,0 @@ -package com.amazon.sqs.javamessaging; - -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.S3Configuration; -import software.amazon.awssdk.services.sqs.SqsClient; -import software.amazon.awssdk.services.sqs.model.CreateQueueRequest; -import software.amazon.awssdk.services.sqs.model.DeleteQueueRequest; -import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest; -import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; -import software.amazon.awssdk.services.sqs.model.SendMessageRequest; -import software.amazon.awssdk.core.ResponseInputStream; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; - -import java.io.IOException; -import java.net.URI; -import java.nio.charset.StandardCharsets; - -import static org.junit.jupiter.api.Assertions.*; - -/** - * Integration test for streaming functionality using LocalStack. - * This test requires LocalStack to be running. - */ -@EnabledIfEnvironmentVariable(named = "LOCALSTACK_ENDPOINT", matches = ".*") -public class StreamingIntegrationTest { - - private static final String LOCALSTACK_ENDPOINT = System.getenv("LOCALSTACK_ENDPOINT") != null ? 
- System.getenv("LOCALSTACK_ENDPOINT") : "http://localhost:4566"; - private static final String BUCKET_NAME = "offload-bucket"; - private static final String QUEUE_NAME = "streaming-test-queue"; - - private static SqsClient sqsClient; - private static S3Client s3Client; - private static AmazonSQSExtendedClient extendedClient; - private static String queueUrl; - - @BeforeAll - public static void setup() { - // Create clients pointing to LocalStack - sqsClient = SqsClient.builder() - .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) - .credentialsProvider(StaticCredentialsProvider.create( - AwsBasicCredentials.create("dummy", "dummy"))) - .region(Region.of("ap-southeast-1")) - .build(); - - s3Client = S3Client.builder() - .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) - .credentialsProvider(StaticCredentialsProvider.create( - AwsBasicCredentials.create("dummy", "dummy"))) - .region(Region.of("ap-southeast-1")) - .serviceConfiguration(S3Configuration.builder() - .pathStyleAccessEnabled(true) - .build()) - .build(); - - // Create S3AsyncClient with same configuration for TransferManager - // software.amazon.awssdk.services.s3.S3AsyncClient s3AsyncClient = - // software.amazon.awssdk.services.s3.S3AsyncClient.builder() - // .endpointOverride(URI.create(LOCALSTACK_ENDPOINT)) - // .credentialsProvider(StaticCredentialsProvider.create( - // AwsBasicCredentials.create("dummy", "dummy"))) - // .region(Region.of("ap-southeast-1")) - // .serviceConfiguration(S3Configuration.builder() - // .pathStyleAccessEnabled(true) - // .build()) - // .multipartEnabled(true) - // .multipartConfiguration(conf -> conf - // .minimumPartSizeInBytes(10 * 1024 * 1024L) // 10MB - // .thresholdInBytes(16 * 1024 * 1024L)) // 16MB - // .build(); - - // Create extended client with stream support - ExtendedClientConfiguration config = new ExtendedClientConfiguration() - .withPayloadSupportEnabled(s3Client, BUCKET_NAME) - .withStreamUploadEnabled(true) - .withStreamUploadThreshold(5 * 1024 * 1024) // 5MB - .withStreamUploadPartSize(10 * 1024 * 1024) // 10MB - .withPayloadSizeThreshold(256 * 1024) // 256KB - .withS3Region("ap-southeast-1"); - - System.out.println("S3 Region configured: " + config.getS3Region()); - - extendedClient = new AmazonSQSExtendedClient(sqsClient, config); - - // Create queue - CreateQueueRequest createQueueRequest = CreateQueueRequest.builder() - .queueName(QUEUE_NAME) - .build(); - sqsClient.createQueue(createQueueRequest); - - // Get queue URL - GetQueueUrlRequest getQueueUrlRequest = GetQueueUrlRequest.builder() - .queueName(QUEUE_NAME) - .build(); - queueUrl = sqsClient.getQueueUrl(getQueueUrlRequest).queueUrl(); - - System.out.println("Integration test setup complete. 
Using LocalStack at: " + LOCALSTACK_ENDPOINT); - } - - @AfterAll - public static void cleanup() { - if (queueUrl != null) { - try { - DeleteQueueRequest deleteQueueRequest = DeleteQueueRequest.builder() - .queueUrl(queueUrl) - .build(); - sqsClient.deleteQueue(deleteQueueRequest); - } catch (Exception e) { - System.err.println("Failed to delete queue: " + e.getMessage()); - } - } - - if (sqsClient != null) { - sqsClient.close(); - } - if (s3Client != null) { - s3Client.close(); - } - if (extendedClient != null) { - extendedClient.close(); - } - } - - // @Test - // public void testSendAndReceiveLargeMessageWithStreaming() throws IOException { - // // Create a large message (6MB) that will trigger stream upload - // String largeMessage = generateLargeMessage(6 * 1024 * 1024); // 6MB - - // // Send the large message - // SendMessageRequest sendRequest = SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .messageBody(largeMessage) - // .build(); - - // extendedClient.sendMessage(sendRequest); - // System.out.println("Sent large message (" + largeMessage.length() + " chars, ~" + - // (largeMessage.length() * 2 / 1024 / 1024) + "MB)"); - - // // Receive the message using streaming - // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() - // .queueUrl(queueUrl) - // .maxNumberOfMessages(1) - // .build(); - - // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); - - // assertEquals(1, streamResponse.streamMessages().size(), - // "Should receive exactly one message"); - - // StreamMessage streamMessage = streamResponse.streamMessages().get(0); - - // assertTrue(streamMessage.hasStreamPayload(), - // "Large message should have streaming payload"); - - // ResponseInputStream payloadStream = streamMessage.getPayloadStream(); - // assertNotNull(payloadStream, "Payload stream should not be null"); - - // String receivedContent = readStreamContent(payloadStream); - - // assertEquals(largeMessage, receivedContent, - // "Received content should match sent message"); - - // System.out.println("Successfully received and streamed large message content"); - // } - - // @Test - // public void testSendAndReceiveSmallMessageWithoutStreaming() { - // String smallMessage = "This is a small message that stays in SQS"; - - // SendMessageRequest sendRequest = SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .messageBody(smallMessage) - // .build(); - - // extendedClient.sendMessage(sendRequest); - // System.out.println("Sent small message (" + smallMessage.length() + " chars)"); - - // // Receive the message using streaming - // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() - // .queueUrl(queueUrl) - // .maxNumberOfMessages(1) - // .build(); - - // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); - - // assertEquals(1, streamResponse.streamMessages().size(), - // "Should receive exactly one message"); - - // StreamMessage streamMessage = streamResponse.streamMessages().get(0); - - // assertFalse(streamMessage.hasStreamPayload(), - // "Small message should not have streaming payload"); - - // assertEquals(smallMessage, streamMessage.getMessage().body(), - // "Small message content should be in message body"); - - // System.out.println("Successfully received small message without streaming"); - // } - - // @Test - // public void testMixedMessageTypes() throws IOException { - // String smallMessage = "Small message"; - // String largeMessage = 
generateLargeMessage(4 * 1024 * 1024); // 4MB - - // extendedClient.sendMessage(SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .messageBody(smallMessage) - // .build()); - - // extendedClient.sendMessage(SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .messageBody(largeMessage) - // .build()); - - // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() - // .queueUrl(queueUrl) - // .maxNumberOfMessages(10) - // .build(); - - // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); - - // assertEquals(2, streamResponse.streamMessages().size(), - // "Should receive exactly two messages"); - - // StreamMessage smallStreamMessage = null; - // StreamMessage largeStreamMessage = null; - - // for (StreamMessage msg : streamResponse.streamMessages()) { - // if (msg.hasStreamPayload()) { - // largeStreamMessage = msg; - // } else { - // smallStreamMessage = msg; - // } - // } - - // assertNotNull(smallStreamMessage, "Should have small message"); - // assertNotNull(largeStreamMessage, "Should have large message"); - - // assertEquals(smallMessage, smallStreamMessage.getMessage().body()); - - // ResponseInputStream payloadStream = largeStreamMessage.getPayloadStream(); - // String receivedLargeContent = readStreamContent(payloadStream); - // assertEquals(largeMessage, receivedLargeContent); - - // System.out.println("Successfully handled mixed small and large messages"); - // } - - // @Test - // public void testStreamingUploadWithMultipartConfiguration() throws IOException { - // // Create a 20MB message that will definitely trigger multipart upload - // // With 10MB part size and 16MB threshold configured - // String veryLargeMessage = generateLargeMessage(20 * 1024 * 1024); // 20MB - // byte[] messageBytes = veryLargeMessage.getBytes(StandardCharsets.UTF_8); - - // System.out.println("Sending very large message (" + - // (messageBytes.length / 1024 / 1024) + "MB) - should use multipart upload"); - - // // Send the message using sendStreamMessage - this should trigger streaming upload with multipart - // SendMessageRequest sendRequest = SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .build(); - - // long startTime = System.currentTimeMillis(); - // java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); - // extendedClient.sendStreamMessage(sendRequest, messageStream, messageBytes.length); - // long uploadTime = System.currentTimeMillis() - startTime; - - // System.out.println("Upload completed in " + uploadTime + "ms using TransferManager with multipart"); - - // // Receive using streaming to avoid loading entire message into memory - // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() - // .queueUrl(queueUrl) - // .maxNumberOfMessages(1) - // .waitTimeSeconds(10) - // .build(); - - // startTime = System.currentTimeMillis(); - // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); - // long receiveTime = System.currentTimeMillis() - startTime; - - // System.out.println("Receive completed in " + receiveTime + "ms"); - - // assertEquals(1, streamResponse.streamMessages().size(), - // "Should receive exactly one message"); - - // StreamMessage streamMessage = streamResponse.streamMessages().get(0); - - // assertTrue(streamMessage.hasStreamPayload(), - // "Very large message should have streaming payload"); - - // ResponseInputStream payloadStream = streamMessage.getPayloadStream(); - // 
assertNotNull(payloadStream, "Payload stream should not be null"); - - // // Read the content in chunks to demonstrate true streaming - // startTime = System.currentTimeMillis(); - // String receivedContent = readStreamContent(payloadStream); - // long readTime = System.currentTimeMillis() - startTime; - - // System.out.println("Stream read completed in " + readTime + "ms"); - - // assertEquals(veryLargeMessage.length(), receivedContent.length(), - // "Received content length should match sent message length"); - - // // Verify start and end markers to ensure content integrity - // assertTrue(receivedContent.startsWith("START:"), - // "Content should start with START marker"); - // assertTrue(receivedContent.endsWith(":END"), - // "Content should end with END marker"); - - // System.out.println("Successfully sent and received 20MB message using multipart streaming"); - // } - - // @Test - // public void testStreamingWithCustomPartSizeAndThreshold() throws IOException { - // // This test verifies that the configured part size (10MB) and threshold (16MB) - // // are being used correctly for multipart uploads - - // // Create a message just above the threshold (17MB) - // String largeMessage = generateLargeMessage(17 * 1024 * 1024); // 17MB - // byte[] messageBytes = largeMessage.getBytes(StandardCharsets.UTF_8); - - // System.out.println("Testing with " + (messageBytes.length / 1024 / 1024) + - // "MB message - above 16MB threshold, should trigger multipart"); - - // SendMessageRequest sendRequest = SendMessageRequest.builder() - // .queueUrl(queueUrl) - // .build(); - - // java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); - // extendedClient.sendStreamMessage(sendRequest, messageStream, messageBytes.length); - // System.out.println("Successfully sent " + (messageBytes.length / 1024 / 1024) + - // "MB message with multipart upload"); - - // // Receive and verify - // ReceiveMessageRequest receiveRequest = ReceiveMessageRequest.builder() - // .queueUrl(queueUrl) - // .maxNumberOfMessages(1) - // .build(); - - // ReceiveStreamMessageResponse streamResponse = extendedClient.receiveMessageAsStream(receiveRequest); - - // assertEquals(1, streamResponse.streamMessages().size()); - // StreamMessage streamMessage = streamResponse.streamMessages().get(0); - - // assertTrue(streamMessage.hasStreamPayload(), - // "Message above threshold should have streaming payload"); - - // String receivedContent = readStreamContent(streamMessage.getPayloadStream()); - - // assertEquals(largeMessage.length(), receivedContent.length(), - // "Received content should match sent message length"); - - // System.out.println("Successfully verified custom part size and threshold configuration"); - // } - - @Test - public void testMemoryUsageComparisonSendTraditionalVsStreaming() throws IOException, InterruptedException { - // Test 1: Compare memory usage for SENDING large messages - // In real-world: sender and receiver are on different machines with separate memory - - final int messageSizeBytes = 50 * 1024 * 1024; // 50MB - - System.out.println("\n=== SEND Memory Usage Comparison ==="); - System.out.println("Message size: " + (messageSizeBytes / 1024 / 1024) + "MB"); - System.out.println("Comparing traditional sendMessage() vs sendStreamMessage()"); - - Runtime runtime = Runtime.getRuntime(); - - // TRADITIONAL SEND TEST FIRST - System.out.println("\n--- Traditional sendMessage() ---"); - - System.gc(); - Thread.sleep(200); - long traditionalStartMemory = runtime.totalMemory() - 
runtime.freeMemory(); - long traditionalPeakMemory = traditionalStartMemory; - System.out.println("Baseline memory: " + (traditionalStartMemory / 1024 / 1024) + "MB"); - - String largeMessage = generateLargeMessage(messageSizeBytes); - long afterLoadMemory = runtime.totalMemory() - runtime.freeMemory(); - traditionalPeakMemory = Math.max(traditionalPeakMemory, afterLoadMemory); - System.out.println("After loading file into String: " + (afterLoadMemory / 1024 / 1024) + "MB (+" + - ((afterLoadMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - - SendMessageRequest sendRequest1 = SendMessageRequest.builder() - .queueUrl(queueUrl) - .messageBody(largeMessage) - .build(); - - long beforeSend = System.currentTimeMillis(); - extendedClient.sendMessage(sendRequest1); - long sendTime = System.currentTimeMillis() - beforeSend; - - long afterSendMemory = runtime.totalMemory() - runtime.freeMemory(); - traditionalPeakMemory = Math.max(traditionalPeakMemory, afterSendMemory); - - System.out.println("Send time: " + sendTime + "ms"); - System.out.println("Memory after send: " + (afterSendMemory / 1024 / 1024) + "MB (+" + - ((afterSendMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Peak memory: " + (traditionalPeakMemory / 1024 / 1024) + "MB (+" + - ((traditionalPeakMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - - assertTrue(largeMessage.startsWith("START:"), "Traditional send message should start with START:"); - assertTrue(largeMessage.endsWith(":END"), "Traditional send message should end with :END"); - - long traditionalSendMemoryUsed = traditionalPeakMemory - traditionalStartMemory; - - largeMessage = null; - System.gc(); - Thread.sleep(200); - - // STREAMING SEND TEST SECOND (fresh memory state) - System.out.println("\n--- Streaming sendStreamMessage() ---"); - - System.gc(); - Thread.sleep(200); - long streamingStartMemory = runtime.totalMemory() - runtime.freeMemory(); - long streamingPeakMemory = streamingStartMemory; - System.out.println("Baseline memory: " + (streamingStartMemory / 1024 / 1024) + "MB"); - - byte[] messageBytes = generateLargeMessage(messageSizeBytes).getBytes(StandardCharsets.UTF_8); - long afterGenerateMemory = runtime.totalMemory() - runtime.freeMemory(); - streamingPeakMemory = Math.max(streamingPeakMemory, afterGenerateMemory); - System.out.println("After generating bytes for stream: " + (afterGenerateMemory / 1024 / 1024) + "MB (+" + - ((afterGenerateMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - - SendMessageRequest sendRequest2 = SendMessageRequest.builder() - .queueUrl(queueUrl) - .build(); - - long beforeStreamSend = System.currentTimeMillis(); - java.io.InputStream messageStream = new java.io.ByteArrayInputStream(messageBytes); - extendedClient.sendStreamMessage(sendRequest2, messageStream, messageBytes.length); - long streamSendTime = System.currentTimeMillis() - beforeStreamSend; - - long afterStreamSendMemory = runtime.totalMemory() - runtime.freeMemory(); - streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamSendMemory); - - System.out.println("Send time: " + streamSendTime + "ms"); - System.out.println("Memory after send: " + (afterStreamSendMemory / 1024 / 1024) + "MB (+" + - ((afterStreamSendMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Peak memory: " + (streamingPeakMemory / 1024 / 1024) + "MB (+" + - ((streamingPeakMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - - String messageBytesStr = new String(messageBytes, StandardCharsets.UTF_8); - 
assertTrue(messageBytesStr.startsWith("START:"), "Streaming send message should start with START:"); - assertTrue(messageBytesStr.endsWith(":END"), "Streaming send message should end with :END"); - messageBytesStr = null; - - long streamingSendMemoryUsed = streamingPeakMemory - streamingStartMemory; - - messageBytes = null; - System.gc(); - Thread.sleep(200); - - System.out.println("\n=== SEND Comparison ==="); - System.out.println("Traditional send peak: " + (traditionalSendMemoryUsed / 1024 / 1024) + "MB"); - System.out.println("Streaming send peak: " + (streamingSendMemoryUsed / 1024 / 1024) + "MB"); - long sendMemorySaved = traditionalSendMemoryUsed - streamingSendMemoryUsed; - double sendPercentSaved = traditionalSendMemoryUsed > 0 ? - (sendMemorySaved * 100.0) / traditionalSendMemoryUsed : 0; - System.out.println("Memory saved: " + (sendMemorySaved / 1024 / 1024) + "MB (" + - String.format("%.1f", sendPercentSaved) + "% reduction)"); - } - - @Test - public void testRealWorldStreamingVsTraditionalReceive() throws IOException, InterruptedException { - // REAL-WORLD SCENARIO: Compare memory usage when actually PROCESSING the content - // This simulates what happens in real applications where you need to consume the data - - final int messageSizeBytes = 50 * 1024 * 1024; // 50MB - - System.out.println("\n=== REAL-WORLD RECEIVE Comparison ==="); - System.out.println("Message size: " + (messageSizeBytes / 1024 / 1024) + "MB"); - System.out.println("Simulating real app: processing content (counting chars, validating data)"); - - Runtime runtime = Runtime.getRuntime(); - - // TRADITIONAL APPROACH: Load entire content, then process it - String largeMessage = generateLargeMessage(messageSizeBytes); - SendMessageRequest sendRequest1 = SendMessageRequest.builder() - .queueUrl(queueUrl) - .messageBody(largeMessage) - .build(); - extendedClient.sendMessage(sendRequest1); - largeMessage = null; // Release reference - System.gc(); - Thread.sleep(200); - - System.out.println("\n--- Traditional: Load entire content, then process ---"); - - System.gc(); - Thread.sleep(200); - long traditionalStartMemory = runtime.totalMemory() - runtime.freeMemory(); - long traditionalPeakMemory = traditionalStartMemory; - System.out.println("Baseline memory: " + (traditionalStartMemory / 1024 / 1024) + "MB"); - - ReceiveMessageRequest receiveRequest1 = ReceiveMessageRequest.builder() - .queueUrl(queueUrl) - .maxNumberOfMessages(1) - .build(); - - long beforeReceive = System.currentTimeMillis(); - software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse response = - extendedClient.receiveMessage(receiveRequest1); - long receiveTime = System.currentTimeMillis() - beforeReceive; - - long afterReceiveMemory = runtime.totalMemory() - runtime.freeMemory(); - traditionalPeakMemory = Math.max(traditionalPeakMemory, afterReceiveMemory); - - // REAL PROCESSING: Access the body (loads entire content) - String receivedBody = response.messages().get(0).body(); - long afterBodyAccessMemory = runtime.totalMemory() - runtime.freeMemory(); - traditionalPeakMemory = Math.max(traditionalPeakMemory, afterBodyAccessMemory); - - // Simulate real processing: count characters, validate content - long beforeProcessing = System.currentTimeMillis(); - int charCount = receivedBody.length(); - boolean hasValidContent = receivedBody.contains("START:") && receivedBody.contains(":END"); - int dataLines = 0; - for (char c : receivedBody.toCharArray()) { - if (c == '\n') dataLines++; - } - long processingTime = System.currentTimeMillis() - 
beforeProcessing; - - long afterProcessingMemory = runtime.totalMemory() - runtime.freeMemory(); - traditionalPeakMemory = Math.max(traditionalPeakMemory, afterProcessingMemory); - - System.out.println("Receive time: " + receiveTime + "ms"); - System.out.println("Processing time: " + processingTime + "ms"); - System.out.println("Content length: " + charCount + " chars"); - System.out.println("Data lines: " + dataLines); - System.out.println("Valid content: " + hasValidContent); - System.out.println("Memory after receive: " + (afterReceiveMemory / 1024 / 1024) + "MB (+" + - ((afterReceiveMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Memory after body access: " + (afterBodyAccessMemory / 1024 / 1024) + "MB (+" + - ((afterBodyAccessMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Memory after processing: " + (afterProcessingMemory / 1024 / 1024) + "MB (+" + - ((afterProcessingMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Peak memory: " + (traditionalPeakMemory / 1024 / 1024) + "MB (+" + - ((traditionalPeakMemory - traditionalStartMemory) / 1024 / 1024) + "MB)"); - - long traditionalMemoryUsed = traditionalPeakMemory - traditionalStartMemory; - - receivedBody = null; - response = null; - System.gc(); - Thread.sleep(200); - - // STREAMING APPROACH: Process content as it streams (realistic scenario) - byte[] messageBytes = generateLargeMessage(messageSizeBytes).getBytes(StandardCharsets.UTF_8); - SendMessageRequest sendRequest2 = SendMessageRequest.builder() - .queueUrl(queueUrl) - .build(); - java.io.InputStream sendStream = new java.io.ByteArrayInputStream(messageBytes); - extendedClient.sendStreamMessage(sendRequest2, sendStream, messageBytes.length); - messageBytes = null; - System.gc(); - Thread.sleep(200); - - System.out.println("\n--- Streaming: Process content as it streams ---"); - - System.gc(); - Thread.sleep(200); - long streamingStartMemory = runtime.totalMemory() - runtime.freeMemory(); - long streamingPeakMemory = streamingStartMemory; - System.out.println("Baseline memory: " + (streamingStartMemory / 1024 / 1024) + "MB"); - - ReceiveMessageRequest receiveRequest2 = ReceiveMessageRequest.builder() - .queueUrl(queueUrl) - .maxNumberOfMessages(1) - .build(); - - long beforeStreamReceive = System.currentTimeMillis(); - ReceiveStreamMessageResponse streamResponse = - extendedClient.receiveMessageAsStream(receiveRequest2); - long streamReceiveTime = System.currentTimeMillis() - beforeStreamReceive; - - long afterStreamReceiveMemory = runtime.totalMemory() - runtime.freeMemory(); - streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamReceiveMemory); - - System.out.println("Receive time: " + streamReceiveTime + "ms"); - System.out.println("Memory after receive: " + (afterStreamReceiveMemory / 1024 / 1024) + "MB (+" + - ((afterStreamReceiveMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - - // REAL STREAMING PROCESSING: Process content as it streams - StreamMessage streamMessage = streamResponse.streamMessages().get(0); - ResponseInputStream payloadStream = streamMessage.getPayloadStream(); - - long beforeStreamProcessing = System.currentTimeMillis(); - long totalBytesRead = 0; - int streamCharCount = 0; - boolean streamHasValidContent = false; - int streamDataLines = 0; - boolean foundStart = false; - boolean foundEnd = false; - - try (ResponseInputStream s = payloadStream) { - byte[] buffer = new byte[8192]; // 8KB buffer - int bytesRead; - - // Patterns to search for 
anywhere in the stream - String startPattern = "START:"; - String endPattern = ":END"; - - // Sliding window to handle patterns that span chunk boundaries - byte[] previousChunkTail = new byte[0]; - int maxPatternLength = Math.max(startPattern.length(), endPattern.length()); - - while ((bytesRead = s.read(buffer)) != -1) { - totalBytesRead += bytesRead; - streamCharCount += bytesRead; - - // Count newlines by scanning bytes directly (memory efficient) - for (int i = 0; i < bytesRead; i++) { - if (buffer[i] == '\n') streamDataLines++; - } - - // Pattern matching: search in current chunk + overlap from previous chunk - if (!foundStart || !foundEnd) { - // Combine previous chunk tail with current chunk for pattern search - // This handles patterns that span chunk boundaries - byte[] searchBuffer; - - if (previousChunkTail.length > 0) { - searchBuffer = new byte[previousChunkTail.length + bytesRead]; - System.arraycopy(previousChunkTail, 0, searchBuffer, 0, previousChunkTail.length); - System.arraycopy(buffer, 0, searchBuffer, previousChunkTail.length, bytesRead); - } else { - searchBuffer = buffer; - } - - int searchLength = (searchBuffer == buffer) ? bytesRead : searchBuffer.length; - String chunk = new String(searchBuffer, 0, searchLength, StandardCharsets.UTF_8); - - if (!foundStart && chunk.contains(startPattern)) { - foundStart = true; - } - if (!foundEnd && chunk.contains(endPattern)) { - foundEnd = true; - } - - chunk = null; // Release immediately - - // Save tail of current chunk for next iteration (to handle boundary patterns) - int tailLength = Math.min(bytesRead, maxPatternLength); - previousChunkTail = new byte[tailLength]; - System.arraycopy(buffer, bytesRead - tailLength, previousChunkTail, 0, tailLength); - } - - long currentMemory = runtime.totalMemory() - runtime.freeMemory(); - streamingPeakMemory = Math.max(streamingPeakMemory, currentMemory); - } - } - - streamHasValidContent = foundStart && foundEnd; - long streamProcessingTime = System.currentTimeMillis() - beforeStreamProcessing; - - long afterStreamProcessingMemory = runtime.totalMemory() - runtime.freeMemory(); - streamingPeakMemory = Math.max(streamingPeakMemory, afterStreamProcessingMemory); - - System.out.println("Processing time: " + streamProcessingTime + "ms"); - System.out.println("Total bytes: " + (totalBytesRead / 1024 / 1024) + "MB"); - System.out.println("Content length: " + streamCharCount + " chars"); - System.out.println("Data lines: " + streamDataLines); - System.out.println("Valid content: " + streamHasValidContent); - System.out.println("Memory after processing: " + (afterStreamProcessingMemory / 1024 / 1024) + "MB (+" + - ((afterStreamProcessingMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - System.out.println("Peak memory: " + (streamingPeakMemory / 1024 / 1024) + "MB (+" + - ((streamingPeakMemory - streamingStartMemory) / 1024 / 1024) + "MB)"); - - long streamingMemoryUsed = streamingPeakMemory - streamingStartMemory; - - System.out.println("\n=== REAL-WORLD Comparison ==="); - System.out.println("Traditional peak: " + (traditionalMemoryUsed / 1024 / 1024) + "MB"); - System.out.println("Streaming peak: " + (streamingMemoryUsed / 1024 / 1024) + "MB"); - long memorySaved = traditionalMemoryUsed - streamingMemoryUsed; - double percentSaved = traditionalMemoryUsed > 0 ? 
- (memorySaved * 100.0) / traditionalMemoryUsed : 0; - System.out.println("Memory saved: " + (memorySaved / 1024 / 1024) + "MB (" + - String.format("%.1f", percentSaved) + "% reduction)"); - - assertTrue(totalBytesRead >= messageSizeBytes - 100 && totalBytesRead <= messageSizeBytes + 100); - assertEquals(charCount, streamCharCount, "Both approaches should count same characters"); - assertEquals(dataLines, streamDataLines, "Both approaches should count same lines"); - assertEquals(hasValidContent, streamHasValidContent, "Both approaches should validate content same way"); - } - - - private String generateLargeMessage(int sizeInBytes) { - int numChars = sizeInBytes; - StringBuilder sb = new StringBuilder(numChars); - sb.append("START:"); - for (int i = 0; i < numChars - 12; i++) { - sb.append((char) ('A' + (i % 26))); - } - sb.append(":END"); - return sb.toString(); - } - - private String readStreamContent(ResponseInputStream stream) throws IOException { - if (stream == null) { - return ""; - } - try (ResponseInputStream s = stream) { - byte[] bytes = s.readAllBytes(); - return new String(bytes, StandardCharsets.UTF_8); - } - } -} \ No newline at end of file From d749d1f3c30554722503c57d1ae4dd3d1de84488 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 14:44:08 +0700 Subject: [PATCH 6/8] adjust sync client --- pom.xml | 2 +- .../AmazonSQSExtendedClient.java | 40 +++++++------------ .../ExtendedAsyncClientConfiguration.java | 30 -------------- .../ExtendedClientConfiguration.java | 5 --- .../AmazonSQSExtendedAsyncClientTest.java | 22 +++++----- .../AmazonSQSExtendedClientTest.java | 20 +++++----- .../ExtendedAsyncClientConfigurationTest.java | 37 +---------------- 7 files changed, 36 insertions(+), 120 deletions(-) diff --git a/pom.xml b/pom.xml index 2b97718..8a79f43 100644 --- a/pom.xml +++ b/pom.xml @@ -57,7 +57,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 2.2.0 + 2.3.0 diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index aab4778..9a6ed07 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -35,18 +35,13 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.awscore.AwsRequest; import software.amazon.awssdk.awscore.exception.AwsServiceException; import software.amazon.awssdk.core.ResponseInputStream; -import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.util.VersionInfo; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.model.NoSuchKeyException; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.BatchEntryIdsNotDistinctException; @@ -157,22 +152,17 @@ public AmazonSQSExtendedClient(SqsClient sqsClient) { public AmazonSQSExtendedClient(SqsClient sqsClient, ExtendedClientConfiguration extendedClientConfig) { super(sqsClient); 
this.clientConfiguration = new ExtendedClientConfiguration(extendedClientConfig); + if (clientConfiguration.isStreamUploadEnabled()) { - S3AsyncClient s3AsyncClient = S3AsyncClient.builder() - .multipartEnabled(true) - .multipartConfiguration( - multipartConfig -> multipartConfig - .minimumPartSizeInBytes(clientConfiguration.getStreamUploadPartSize()) - .thresholdInBytes(clientConfiguration.getStreamUploadThreshold()) - ) - .credentialsProvider(DefaultCredentialsProvider.create()) - .region(Region.of(clientConfiguration.getS3Region())) - .build(); - - S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), s3AsyncClient, clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); + S3Dao s3Dao = new S3Dao( + clientConfiguration.getS3Client(), + clientConfiguration.getServerSideEncryptionStrategy(), + clientConfiguration.getObjectCannedACL(), + clientConfiguration.getStreamUploadPartSize(), + clientConfiguration.getStreamUploadThreshold()); this.payloadStore = new S3BackedStreamPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); } else { - S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), clientConfiguration.getS3AsyncClient(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); + S3Dao s3Dao = new S3Dao(clientConfiguration.getS3Client(), clientConfiguration.getServerSideEncryptionStrategy(), clientConfiguration.getObjectCannedACL()); this.payloadStore = new S3BackedPayloadStore(s3Dao, clientConfiguration.getS3BucketName()); } } @@ -787,7 +777,7 @@ public SendMessageResponse sendStreamMessage(SendMessageRequest sendMessageReque if (!clientConfiguration.isPayloadSupportEnabled()) { // Convert stream to string for non-extended client try { - String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStream); + String messageBody = IoUtils.toUtf8String(messageBodyStream); sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); return super.sendMessage(sendMessageRequest); } catch (IOException e) { @@ -810,7 +800,7 @@ public SendMessageResponse sendStreamMessage(SendMessageRequest sendMessageReque } else { // Convert stream to string for small messages try { - String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStream); + String messageBody = IoUtils.toUtf8String(messageBodyStream); sendMessageRequest = sendMessageRequest.toBuilder().messageBody(messageBody).build(); } catch (IOException e) { throw new RuntimeException("Failed to read from InputStream", e); @@ -831,8 +821,8 @@ public SendMessageResponse sendStreamMessage(SendMessageRequest sendMessageReque * @return Result of the SendMessageBatch operation returned by the service. 
*/ public SendMessageBatchResponse sendStreamMessageBatch(SendMessageBatchRequest sendMessageBatchRequest, - java.util.List messageBodyStreams, - java.util.List contentLengths) { + List messageBodyStreams, + List contentLengths) { if (sendMessageBatchRequest == null) { String errorMessage = "sendMessageBatchRequest cannot be null."; @@ -859,11 +849,11 @@ public SendMessageBatchResponse sendStreamMessageBatch(SendMessageBatchRequest s if (!clientConfiguration.isPayloadSupportEnabled()) { // Convert streams to strings for non-extended client - java.util.List entries = new java.util.ArrayList<>(); + List entries = new ArrayList<>(); for (int i = 0; i < sendMessageBatchRequest.entries().size(); i++) { SendMessageBatchRequestEntry entry = sendMessageBatchRequest.entries().get(i); try { - String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(messageBodyStreams.get(i)); + String messageBody = IoUtils.toUtf8String(messageBodyStreams.get(i)); entries.add(entry.toBuilder().messageBody(messageBody).build()); } catch (IOException e) { throw new RuntimeException("Failed to read from InputStream", e); @@ -891,7 +881,7 @@ public SendMessageBatchResponse sendStreamMessageBatch(SendMessageBatchRequest s } else { // Convert stream to string for small messages try { - String messageBody = software.amazon.awssdk.utils.IoUtils.toUtf8String(stream); + String messageBody = IoUtils.toUtf8String(stream); entry = entry.toBuilder().messageBody(messageBody).build(); } catch (IOException e) { throw new RuntimeException("Failed to read from InputStream", e); diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index baf7d6e..e96fff2 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -212,34 +212,4 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry this.setServerSideEncryptionStrategy(serverSideEncryption); return this; } - - /** - * Enables or disables stream upload support for large payload storage operations. - * @param enabled true to enable stream uploads when threshold exceeded. - * @return updated configuration - */ - public ExtendedAsyncClientConfiguration withStreamUploadEnabled(boolean enabled) { - setStreamUploadEnabled(enabled); - return this; - } - - /** - * Sets the threshold for stream upload in bytes. 
- * @param threshold the threshold in bytes - * @return updated configuration - */ - public ExtendedAsyncClientConfiguration withStreamUploadThreshold(int threshold) { - setStreamUploadThreshold(threshold); - return this; - } - - public ExtendedAsyncClientConfiguration withStreamUploadPartSize(int partSize) { - setStreamUploadPartSize(partSize); - return this; - } - - public ExtendedAsyncClientConfiguration withS3Region(String s3Region) { - setS3Region(s3Region); - return this; - } } \ No newline at end of file diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index a8ddcfc..9e0d41e 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -252,11 +252,6 @@ public ExtendedClientConfiguration withStreamUploadPartSize(int partSize) { return this; } - public ExtendedClientConfiguration withS3Region(String s3Region) { - setS3Region(s3Region); - return this; - } - /** * Enables support for large-payload messages. * diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java index ff941c8..f94512d 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -782,9 +782,8 @@ public void testReceiveMessageAsStream_PayloadSupportDisabled_ReturnsMessagesWit public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessageWithStream() throws IOException { ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withPayloadSizeThreshold(262144) // 256KB - .withStreamUploadEnabled(true) - .withStreamUploadThreshold(1024 * 1024); // 1MB + .withPayloadSizeThreshold(262144); // 256KB + config.setStreamUploadEnabled(true); // Enable stream uploads AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); @@ -813,6 +812,7 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + // Mock S3 to return ResponseInputStream for stream retrieval @SuppressWarnings("unchecked") CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) @@ -834,8 +834,7 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa @Test public void testReceiveMessageAsStream_LargeMessage_WithoutStreamStore_FallsBackToRegularRetrieval() { ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() - .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withStreamUploadEnabled(false); + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); AmazonSQSExtendedAsyncClient clientWithRegularStore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); @@ -880,8 +879,7 @@ public void testReceiveMessageAsStream_LargeMessage_WithoutStreamStore_FallsBack public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundEnabled_DeletesMessage() { ExtendedAsyncClientConfiguration config = new 
ExtendedAsyncClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withIgnorePayloadNotFound(true) - .withStreamUploadEnabled(true); + .withIgnorePayloadNotFound(true); AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); @@ -928,8 +926,7 @@ public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundEnable public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundDisabled_ThrowsException() { ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withIgnorePayloadNotFound(false) - .withStreamUploadEnabled(true); + .withIgnorePayloadNotFound(false); AmazonSQSExtendedAsyncClient clientWithIgnore = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); ReceiveMessageRequest request = ReceiveMessageRequest.builder() @@ -967,9 +964,8 @@ public void testReceiveMessageAsStream_StreamRetrievalFails_IgnoreNotFoundDisabl public void testReceiveMessageAsStream_MultipleMessages_MixedTypes() throws IOException { ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) - .withPayloadSizeThreshold(262144) // 256KB - .withStreamUploadEnabled(true) - .withStreamUploadThreshold(1024 * 1024); // 1MB + .withPayloadSizeThreshold(262144); // 256KB + config.setStreamUploadEnabled(true); // Enable stream uploads AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); @@ -999,12 +995,14 @@ public void testReceiveMessageAsStream_MultipleMessages_MixedTypes() throws IOEx .messages(Arrays.asList(smallMessage, largeMessage)) .build(); + @SuppressWarnings("unchecked") ResponseInputStream mockStream = mock(ResponseInputStream.class); when(mockStream.read(any(byte[].class))).thenReturn(-1); when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) .thenReturn(CompletableFuture.completedFuture(sqsResponse)); + @SuppressWarnings("unchecked") CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) .thenReturn(futureStream); diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index e51d6cb..858abb6 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -85,6 +85,8 @@ import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.io.ByteArrayInputStream; /** @@ -121,7 +123,6 @@ public class AmazonSQSExtendedClientTest { // Stream upload thresholds private static final int STREAM_UPLOAD_THRESHOLD = 5 * 1024 * 1024; // 5MB default - private static final int LESS_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD - 1; private static final int MORE_THAN_STREAM_THRESHOLD = STREAM_UPLOAD_THRESHOLD + 1; // Stream part size @@ -815,7 +816,6 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withStreamUploadEnabled(true) .withStreamUploadPartSize(STREAM_UPLOAD_PART_SIZE) - .withS3Region("ap-south-1") .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); AmazonSQSExtendedClient sqsExtended = new AmazonSQSExtendedClient(mockSqsBackend, 
extendedClientConfiguration); @@ -839,13 +839,12 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa public void testSendStreamMessage_LargeFileUpload_StoresInS3AndSendsPointer() { int fileSizeBytes = MORE_THAN_STREAM_THRESHOLD; String fileContent = generateStringWithLength(fileSizeBytes); - java.io.InputStream fileStream = new java.io.ByteArrayInputStream(fileContent.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + InputStream fileStream = new ByteArrayInputStream(fileContent.getBytes(StandardCharsets.UTF_8)); ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withStreamUploadEnabled(true) - .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD) - .withS3Region("us-east-1"); + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig)); @@ -916,8 +915,7 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { ExtendedClientConfiguration streamConfig = new ExtendedClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withStreamUploadEnabled(true) - .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD) // 5MB - .withS3Region("us-east-1"); + .withStreamUploadThreshold(STREAM_UPLOAD_THRESHOLD); // 5MB SqsClient streamClient = spy(new AmazonSQSExtendedClient(mockSqsBackend, streamConfig)); @@ -931,7 +929,7 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { .id("msg1") .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("small").dataType("String").build())) .build()); - streams.add(new java.io.ByteArrayInputStream(smallMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + streams.add(new ByteArrayInputStream(smallMsg.getBytes(StandardCharsets.UTF_8))); contentLengths.add((long) smallMsg.length()); // Large message (300KB - above SQS limit but below stream threshold) @@ -940,7 +938,7 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { .id("msg2") .messageAttributes(ImmutableMap.of("size", MessageAttributeValue.builder().stringValue("large").dataType("String").build())) .build()); - streams.add(new java.io.ByteArrayInputStream(largeMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + streams.add(new java.io.ByteArrayInputStream(largeMsg.getBytes(StandardCharsets.UTF_8))); contentLengths.add((long) largeMsg.length()); // Very large message (6MB - above stream threshold, uses multipart) @@ -948,7 +946,7 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { entries.add(SendMessageBatchRequestEntry.builder() .id("msg3") .build()); - streams.add(new java.io.ByteArrayInputStream(veryLargeMsg.getBytes(java.nio.charset.StandardCharsets.UTF_8))); + streams.add(new java.io.ByteArrayInputStream(veryLargeMsg.getBytes(StandardCharsets.UTF_8))); contentLengths.add((long) veryLargeMsg.length()); SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder() @@ -976,7 +974,7 @@ public void testSendStreamMessageBatch_MixedSizes_OnlyLargeMessagesUseS3() { public void testSendStreamMessage_WithEncryption_AppliesKMSToS3Upload() { int dataSize = MORE_THAN_SQS_SIZE_LIMIT; String sensitiveData = generateStringWithLength(dataSize); - java.io.InputStream stream = new java.io.ByteArrayInputStream(sensitiveData.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + InputStream stream = new ByteArrayInputStream(sensitiveData.getBytes(StandardCharsets.UTF_8)); 
SendMessageRequest request = SendMessageRequest.builder() .queueUrl(SQS_QUEUE_URL) diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java index ae8fa81..b772586 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java @@ -87,45 +87,10 @@ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() { } @Test - public void testStreamUploadDisabledByDefault() { - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - - assertFalse(extendedClientConfiguration.isStreamUploadEnabled()); - assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadThreshold()); // 5MB default - } - - @Test - public void testStreamUploadEnabled() { + public void testStreamUploadEnabledEnabled() { ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); extendedClientConfiguration.withStreamUploadEnabled(true); assertTrue(extendedClientConfiguration.isStreamUploadEnabled()); } - - @Test - public void testStreamUploadThresholdCustomValue() { - int customThreshold = 10 * 1024 * 1024; // 10MB - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withStreamUploadThreshold(customThreshold); - - assertEquals(customThreshold, extendedClientConfiguration.getStreamUploadThreshold()); - } - - @Test - public void testStreamUploadPartSizeCustomValue() { - int customPartSize = 10 * 1024 * 1024; // 10MB - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withStreamUploadPartSize(customPartSize); - - assertEquals(customPartSize, extendedClientConfiguration.getStreamUploadPartSize()); - } - - @Test - public void testStreamUploadPartSizeBelowMinimumRoundedUpTo5MB() { - int belowMinimum = 3 * 1024 * 1024; // 3MB (below 5MB minimum) - ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); - extendedClientConfiguration.withStreamUploadPartSize(belowMinimum); - - assertEquals(5 * 1024 * 1024, extendedClientConfiguration.getStreamUploadPartSize()); - } } From 6811bea7d18dfdafea353159c8fc74b5424674d8 Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 15:57:26 +0700 Subject: [PATCH 7/8] cleanup --- .../sqs/javamessaging/AmazonSQSExtendedAsyncClient.java | 4 ++-- .../sqs/javamessaging/ExtendedAsyncClientConfiguration.java | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 0bb5ceb..7e8b1d6 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -134,8 +134,8 @@ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient, super(sqsClient); this.clientConfiguration = new ExtendedAsyncClientConfiguration(extendedClientConfig); S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(), - clientConfiguration.getServerSideEncryptionStrategy(), - clientConfiguration.getObjectCannedACL()); + 
clientConfiguration.getServerSideEncryptionStrategy(), + clientConfiguration.getObjectCannedACL()); if (clientConfiguration.isStreamUploadEnabled()) { this.payloadStore = new S3BackedStreamPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index e96fff2..281bf3b 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -212,4 +212,10 @@ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncry this.setServerSideEncryptionStrategy(serverSideEncryption); return this; } + + @Override + public ExtendedAsyncClientConfiguration withStreamUploadEnabled(boolean enabled) { + setStreamUploadEnabled(enabled); + return this; + } } \ No newline at end of file From 6b65e208a016f6d47c2a6d2614b3870b12de6a8f Mon Sep 17 00:00:00 2001 From: akbarsigit Date: Fri, 10 Oct 2025 16:42:19 +0700 Subject: [PATCH 8/8] cleanup --- .../javamessaging/AmazonSQSExtendedAsyncClientTest.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java index f94512d..e3a42e9 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -783,7 +783,7 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) .withPayloadSizeThreshold(262144); // 256KB - config.setStreamUploadEnabled(true); // Enable stream uploads + config.setStreamUploadEnabled(true); AmazonSQSExtendedAsyncClient clientWithStream = new AmazonSQSExtendedAsyncClient(mockSqsBackend, config); @@ -807,12 +807,11 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa @SuppressWarnings("unchecked") ResponseInputStream mockStream = mock(ResponseInputStream.class); - when(mockStream.read(any(byte[].class))).thenReturn(-1); // End of stream + when(mockStream.read(any(byte[].class))).thenReturn(-1); when(mockSqsBackend.receiveMessage(any(ReceiveMessageRequest.class))) .thenReturn(CompletableFuture.completedFuture(sqsResponse)); - // Mock S3 to return ResponseInputStream for stream retrieval @SuppressWarnings("unchecked") CompletableFuture> futureStream = CompletableFuture.completedFuture(mockStream); when(mockS3.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) @@ -826,8 +825,6 @@ public void testReceiveMessageAsStream_LargeMessage_WithStreamStore_ReturnsMessa assertEquals("msg1", streamMessage.getMessage().messageId()); assertTrue(streamMessage.hasStreamPayload()); assertSame(mockStream, streamMessage.getPayloadStream()); - - // Verify receipt handle was modified assertTrue(streamMessage.getMessage().receiptHandle().contains("test-key")); }
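
Usage note (reviewer sketch, not part of the patch series): the snippet below shows how the stream-upload configuration and APIs touched by this series — withPayloadSupportEnabled, withPayloadSizeThreshold, withStreamUploadEnabled, withStreamUploadThreshold, withStreamUploadPartSize, sendStreamMessage, and receiveMessageAsStream — would be wired up from application code, mirroring the removed StreamingIntegrationTest. The queue URL and bucket name are placeholders, the default S3Client/SqsClient builders stand in for whatever clients a caller already has, and StreamMessage is assumed to live in com.amazon.sqs.javamessaging as it does in the tests; treat this as a minimal sketch under those assumptions rather than an official example.

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;

    import software.amazon.awssdk.services.s3.S3Client;
    import software.amazon.awssdk.services.sqs.SqsClient;
    import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
    import software.amazon.awssdk.services.sqs.model.SendMessageRequest;

    import com.amazon.sqs.javamessaging.AmazonSQSExtendedClient;
    import com.amazon.sqs.javamessaging.ExtendedClientConfiguration;
    import com.amazon.sqs.javamessaging.StreamMessage;

    public class StreamUploadUsageSketch {
        public static void main(String[] args) throws IOException {
            // Placeholder identifiers; substitute real values.
            String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/example-queue";
            String bucketName = "example-offload-bucket";

            S3Client s3 = S3Client.create();
            SqsClient sqs = SqsClient.create();

            // Mirror the settings exercised in the tests: offload payloads above 256KB,
            // and use the stream-upload path above 5MB with 10MB parts.
            ExtendedClientConfiguration config = new ExtendedClientConfiguration()
                .withPayloadSupportEnabled(s3, bucketName)
                .withPayloadSizeThreshold(256 * 1024)
                .withStreamUploadEnabled(true)
                .withStreamUploadThreshold(5 * 1024 * 1024)
                .withStreamUploadPartSize(10 * 1024 * 1024);

            try (AmazonSQSExtendedClient client = new AmazonSQSExtendedClient(sqs, config)) {
                // Send a 6MB payload from a stream instead of a String message body.
                byte[] payload = new byte[6 * 1024 * 1024];
                InputStream body = new ByteArrayInputStream(payload);
                SendMessageRequest send = SendMessageRequest.builder().queueUrl(queueUrl).build();
                client.sendStreamMessage(send, body, payload.length);

                // Receive and consume the payload as a stream rather than one large String.
                ReceiveMessageRequest receive = ReceiveMessageRequest.builder()
                    .queueUrl(queueUrl)
                    .maxNumberOfMessages(1)
                    .build();
                for (StreamMessage message : client.receiveMessageAsStream(receive).streamMessages()) {
                    if (message.hasStreamPayload()) {
                        try (InputStream in = message.getPayloadStream()) {
                            byte[] buffer = new byte[8192];
                            long total = 0;
                            int read;
                            while ((read = in.read(buffer)) != -1) {
                                total += read;
                            }
                            System.out.println("Streamed " + total + " bytes from S3");
                        }
                    } else {
                        System.out.println("Inline body: " + message.getMessage().body());
                    }
                }
            }
        }
    }

The thresholds above simply echo the values used in the tests (256KB payload threshold, 5MB stream threshold, 10MB part size); callers would tune them to their own payload profile.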