diff --git a/.changes/next-release/deprecation-AWSSDKforJavav2-6069419.json b/.changes/next-release/deprecation-AWSSDKforJavav2-6069419.json new file mode 100644 index 000000000000..c30842499842 --- /dev/null +++ b/.changes/next-release/deprecation-AWSSDKforJavav2-6069419.json @@ -0,0 +1,6 @@ +{ + "type": "deprecation", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Deprecate `AsyncRequestBody#split` in favor of `AsyncRequestBody#splitCloseable` that takes the same input but returns `SdkPublisher`" +} diff --git a/.changes/next-release/feature-AWSSDKforJavav2-500254f.json b/.changes/next-release/feature-AWSSDKforJavav2-500254f.json new file mode 100644 index 000000000000..5f1a9031602b --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-500254f.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Add `AsyncRequestBody#splitCloseable` API that returns a Publisher of `ClosableAsyncRequestBody`" +} diff --git a/.changes/next-release/feature-AWSSDKforJavav2-72c5f55.json b/.changes/next-release/feature-AWSSDKforJavav2-72c5f55.json new file mode 100644 index 000000000000..d7dbd07cafa4 --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-72c5f55.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Introduce CloseableAsyncRequestBody interface that extends both AsyncRequestBody and SdkAutoClosable interfaces" +} diff --git a/.changes/next-release/feature-AWSSDKforJavav2-ed2c57a.json b/.changes/next-release/feature-AWSSDKforJavav2-ed2c57a.json new file mode 100644 index 000000000000..e30f700e4c33 --- /dev/null +++ b/.changes/next-release/feature-AWSSDKforJavav2-ed2c57a.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "Introduce BufferedSplittableAsyncRequestBody that enables splitting into retryable sub-request bodies." +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 3fd8c3cc0165..55a4249957e8 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -500,7 +500,6 @@ static AsyncRequestBody empty() { return fromBytes(new byte[0]); } - /** * Converts this {@link AsyncRequestBody} to a publisher of {@link AsyncRequestBody}s, each of which publishes a specific * portion of the original data, based on the provided {@link AsyncRequestBodySplitConfiguration}. The default chunk size @@ -513,12 +512,45 @@ static AsyncRequestBody empty() { * than or equal to {@code chunkSizeInBytes}. Note that this behavior may be different if a specific implementation of this * interface overrides this method. * - * @see AsyncRequestBodySplitConfiguration + * @deprecated use {@link #splitCloseable(AsyncRequestBodySplitConfiguration)} instead. */ + @Deprecated default SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { Validate.notNull(splitConfiguration, "splitConfiguration"); + return SplittingPublisher.builder() + .asyncRequestBody(this) + .splitConfiguration(splitConfiguration) + .retryableSubAsyncRequestBodyEnabled(false) + .build() + .map(r -> r); + } - return new SplittingPublisher(this, splitConfiguration); + /** + * Converts this {@link AsyncRequestBody} to a publisher of {@link CloseableAsyncRequestBody}s, each of which publishes + * specific portion of the original data, based on the provided {@link AsyncRequestBodySplitConfiguration}. The default chunk + * size is 2MB and the default buffer size is 8MB. + * + *

+ * The default implementation behaves the same as {@link #split(AsyncRequestBodySplitConfiguration)}. This behavior may + * vary in different implementations. + * + *

+ * Caller is responsible for closing {@link CloseableAsyncRequestBody} when it is ready to be disposed to release any + * resources. + * + *

Note: This method is primarily intended for use by AWS SDK high-level libraries and internal components. + * SDK customers should typically use higher-level APIs provided by service clients rather than calling this method directly. + * + * @see #splitCloseable(Consumer) + * @see AsyncRequestBodySplitConfiguration + */ + default SdkPublisher splitCloseable(AsyncRequestBodySplitConfiguration splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); + return SplittingPublisher.builder() + .asyncRequestBody(this) + .splitConfiguration(splitConfiguration) + .retryableSubAsyncRequestBodyEnabled(false) + .build(); } /** @@ -526,12 +558,29 @@ default SdkPublisher split(AsyncRequestBodySplitConfiguration * avoiding the need to create one manually via {@link AsyncRequestBodySplitConfiguration#builder()}. * * @see #split(AsyncRequestBodySplitConfiguration) + * @deprecated use {@link #splitCloseable(Consumer)} instead */ + @Deprecated default SdkPublisher split(Consumer splitConfiguration) { Validate.notNull(splitConfiguration, "splitConfiguration"); return split(AsyncRequestBodySplitConfiguration.builder().applyMutation(splitConfiguration).build()); } + /** + * This is a convenience method that passes an instance of the {@link AsyncRequestBodySplitConfiguration} builder, + * avoiding the need to create one manually via {@link AsyncRequestBodySplitConfiguration#builder()}. + * + *

Note: This method is primarily intended for use by AWS SDK high-level libraries and internal components. + * SDK customers should typically use higher-level APIs provided by service clients rather than calling this method directly. + * + * @see #splitCloseable(AsyncRequestBodySplitConfiguration) + */ + default SdkPublisher splitCloseable( + Consumer splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); + return splitCloseable(AsyncRequestBodySplitConfiguration.builder().applyMutation(splitConfiguration).build()); + } + @SdkProtectedApi enum BodyType { FILE("File", "f"), diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BufferedSplittableAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BufferedSplittableAsyncRequestBody.java new file mode 100644 index 000000000000..a1c46238dfee --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/BufferedSplittableAsyncRequestBody.java @@ -0,0 +1,113 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.core.internal.async.SplittingPublisher; +import software.amazon.awssdk.utils.Validate; + +/** + * An {@link AsyncRequestBody} decorator that enables splitting into retryable sub-request bodies. + * + *

This wrapper allows any {@link AsyncRequestBody} to be split into multiple parts where each part + * can be retried independently. When split, each sub-body buffers its portion of data, enabling + * resubscription if a retry is needed (e.g., due to network failures or service errors).

+ * + *

Retry Requirements:

+ *

Retry is only possible if all the data has been successfully buffered during the first subscription. + * If the first subscriber fails to consume all the data (e.g., due to early cancellation or errors), + * subsequent retry attempts will fail since the complete data set is not available for resubscription.

+ * + *

Usage Example:

+ * {@snippet : + * AsyncRequestBody originalBody = AsyncRequestBody.fromString("Hello World"); + * BufferedSplittableAsyncRequestBody retryableBody = + * BufferedSplittableAsyncRequestBody.create(originalBody); + * } + * + *

Performance Considerations:

+ *

This implementation buffers data in memory to enable retries, but memory usage is controlled by + * the {@code bufferSizeInBytes} configuration. However, this buffering limits the ability to request + * more data from the original AsyncRequestBody until buffered data is consumed (i.e., when subscribers + * closes sub-body), which may increase latency compared to non-buffered implementations. + * + * @see AsyncRequestBody + * @see AsyncRequestBodySplitConfiguration + * @see CloseableAsyncRequestBody + */ +@SdkPublicApi +public final class BufferedSplittableAsyncRequestBody implements AsyncRequestBody { + private final AsyncRequestBody delegate; + + private BufferedSplittableAsyncRequestBody(AsyncRequestBody delegate) { + this.delegate = delegate; + } + + /** + * Creates a new {@link BufferedSplittableAsyncRequestBody} that wraps the provided {@link AsyncRequestBody}. + * + * @param delegate the {@link AsyncRequestBody} to wrap and make retryable. Must not be null. + * @return a new {@link BufferedSplittableAsyncRequestBody} instance + * @throws NullPointerException if delegate is null + */ + public static BufferedSplittableAsyncRequestBody create(AsyncRequestBody delegate) { + Validate.paramNotNull(delegate, "delegate"); + return new BufferedSplittableAsyncRequestBody(delegate); + } + + @Override + public Optional contentLength() { + return delegate.contentLength(); + } + + /** + * Splits this request body into multiple retryable parts based on the provided configuration. + * + *

Each part returned by the publisher will be a {@link CloseableAsyncRequestBody} that buffers + * its portion of data, enabling resubscription for retry scenarios. This is the key difference from non-buffered splitting - + * each part can be safely retried without data loss. + * + *

The splitting process respects the chunk size and buffer size specified in the configuration + * to optimize memory usage. + * + *

The subscriber MUST close each {@link CloseableAsyncRequestBody} to ensure resource is released + * + * @param splitConfiguration configuration specifying how to split the request body + * @return a publisher that emits retryable closable request body parts + * @see AsyncRequestBodySplitConfiguration + */ + @Override + public SdkPublisher splitCloseable(AsyncRequestBodySplitConfiguration splitConfiguration) { + return SplittingPublisher.builder() + .asyncRequestBody(this) + .splitConfiguration(splitConfiguration) + .retryableSubAsyncRequestBodyEnabled(true) + .build(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + + @Override + public String body() { + return delegate.body(); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/CloseableAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/CloseableAsyncRequestBody.java new file mode 100644 index 000000000000..18c656c3018b --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/CloseableAsyncRequestBody.java @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.SdkAutoCloseable; + +/** + * An extension of {@link AsyncRequestBody} that is closable. + */ +@SdkPublicApi +public interface CloseableAsyncRequestBody extends AsyncRequestBody, SdkAutoCloseable { +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java index a37b226d4bc3..8835d1a007ae 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; @@ -76,6 +77,17 @@ public SdkPublisher split(Consumer splitCloseable(AsyncRequestBodySplitConfiguration splitConfiguration) { + return delegate.splitCloseable(splitConfiguration); + } + + @Override + public SdkPublisher splitCloseable( + Consumer splitConfiguration) { + return delegate.splitCloseable(splitConfiguration); + } + @Override public void subscribe(Subscriber s) { invoke(() -> listener.publisherSubscribe(s), "publisherSubscribe"); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java index a304d75ccf94..e1f6d9e8cb05 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java @@ -76,7 +76,9 @@ public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody, SdkA private final Object lock = new Object(); private boolean closed; - private ByteBuffersAsyncRequestBody(String mimetype, Long length, List buffers) { + private ByteBuffersAsyncRequestBody(String mimetype, + Long length, + List buffers) { this.mimetype = mimetype; this.buffers = buffers; this.length = length; @@ -121,6 +123,10 @@ public String body() { return BodyType.BYTES.getName(); } + public static ByteBuffersAsyncRequestBody of(List buffers, long length) { + return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers); + } + public static ByteBuffersAsyncRequestBody of(List buffers) { long length = buffers.stream() .mapToLong(ByteBuffer::remaining) @@ -129,7 +135,11 @@ public static ByteBuffersAsyncRequestBody of(List buffers) { } public static ByteBuffersAsyncRequestBody of(ByteBuffer... buffers) { - return of(Arrays.asList(buffers)); + List byteBuffers = Arrays.asList(buffers); + long length = byteBuffers.stream() + .mapToLong(ByteBuffer::remaining) + .sum(); + return of(byteBuffers, length); } public static ByteBuffersAsyncRequestBody of(Long length, ByteBuffer... buffers) { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java index f5dcc164f61c..2af70796f4e1 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java @@ -34,6 +34,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.internal.util.Mimetype; import software.amazon.awssdk.core.internal.util.NoopSubscription; @@ -86,6 +87,11 @@ public SdkPublisher split(AsyncRequestBodySplitConfiguration s return new FileAsyncRequestBodySplitHelper(this, splitConfiguration).split(); } + @Override + public SdkPublisher splitCloseable(AsyncRequestBodySplitConfiguration splitConfiguration) { + return split(splitConfiguration).map(body -> new ClosableAsyncRequestBodyWrapper(body)); + } + public Path path() { return path; } @@ -436,4 +442,32 @@ private void signalOnError(Throwable t) { private static AsynchronousFileChannel openInputChannel(Path path) throws IOException { return AsynchronousFileChannel.open(path, StandardOpenOption.READ); } + + private static class ClosableAsyncRequestBodyWrapper implements CloseableAsyncRequestBody { + private final AsyncRequestBody delegate; + + ClosableAsyncRequestBodyWrapper(AsyncRequestBody body) { + this.delegate = body; + } + + @Override + public Optional contentLength() { + return delegate.contentLength(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + + @Override + public void close() { + // no op + } + + @Override + public String body() { + return delegate.body(); + } + } } \ No newline at end of file diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBody.java new file mode 100644 index 000000000000..221f9246a5e6 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBody.java @@ -0,0 +1,122 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.core.internal.util.NoopSubscription; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.SimplePublisher; + +/** + * A {@link SubAsyncRequestBody} implementation that doesn't support resubscribe/retry + */ +@SdkInternalApi +public final class NonRetryableSubAsyncRequestBody implements SubAsyncRequestBody { + private static final Logger log = Logger.loggerFor(NonRetryableSubAsyncRequestBody.class); + private final SubAsyncRequestBodyConfiguration configuration; + private final int partNumber; + private final boolean contentLengthKnown; + private final String sourceBodyName; + private final SimplePublisher delegate = new SimplePublisher<>(); + private final AtomicBoolean subscribeCalled = new AtomicBoolean(false); + private volatile long bufferedLength = 0; + private final Consumer onNumBytesReceived; + private final Consumer onNumBytesConsumed; + + /** + * Creates a new NonRetryableSubAsyncRequestBody with the given configuration. + */ + public NonRetryableSubAsyncRequestBody(SubAsyncRequestBodyConfiguration configuration) { + this.configuration = Validate.paramNotNull(configuration, "configuration"); + this.partNumber = configuration.partNumber(); + this.contentLengthKnown = configuration.contentLengthKnown(); + this.sourceBodyName = configuration.sourceBodyName(); + this.onNumBytesReceived = configuration.onNumBytesReceived(); + this.onNumBytesConsumed = configuration.onNumBytesConsumed(); + } + + @Override + public Optional contentLength() { + return contentLengthKnown ? Optional.of(configuration.maxLength()) : Optional.of(bufferedLength); + } + + public void send(ByteBuffer data) { + log.debug(() -> String.format("Sending bytebuffer %s to part %d", data, partNumber)); + long length = data.remaining(); + bufferedLength += length; + onNumBytesReceived.accept(length); + delegate.send(data).whenComplete((r, t) -> { + onNumBytesConsumed.accept(length); + if (t != null) { + error(t); + } + }); + } + + public void complete() { + log.debug(() -> "Received complete() for part number: " + partNumber); + delegate.complete().whenComplete((r, t) -> { + if (t != null) { + error(t); + } + }); + } + + @Override + public long maxLength() { + return configuration.maxLength(); + } + + @Override + public long receivedBytesLength() { + return bufferedLength; + } + + @Override + public int partNumber() { + return partNumber; + } + + public void error(Throwable error) { + delegate.error(error); + } + + @Override + public void subscribe(Subscriber s) { + if (subscribeCalled.compareAndSet(false, true)) { + delegate.subscribe(s); + } else { + s.onSubscribe(new NoopSubscription(s)); + s.onError(NonRetryableException.create( + "Multiple subscribers detected. This could happen due to a retry attempt. The AsyncRequestBody implementation" + + " provided does not support splitting to retryable/resubscribable AsyncRequestBody. If you need retry " + + "capability or multiple subscriptions, consider using BufferedSplittableAsyncRequestBody to wrap your " + + "AsyncRequestBody.")); + } + } + + @Override + public String body() { + return sourceBodyName; + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBody.java new file mode 100644 index 000000000000..15f8b0199107 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBody.java @@ -0,0 +1,162 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.core.internal.util.NoopSubscription; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; +import software.amazon.awssdk.utils.async.SimplePublisher; + +/** + * A {@link SubAsyncRequestBody} implementation that supports resubscribe/retry once all data has been published to the first + * subscriber + */ +@SdkInternalApi +public final class RetryableSubAsyncRequestBody implements SubAsyncRequestBody { + private static final Logger log = Logger.loggerFor(RetryableSubAsyncRequestBody.class); + /** + * The maximum length of the content this AsyncRequestBody can hold. If the upstream content length is known, this is + * the same as totalLength + */ + private final SubAsyncRequestBodyConfiguration configuration; + private final int partNumber; + private final boolean contentLengthKnown; + private final String sourceBodyName; + + private volatile long bufferedLength = 0; + private volatile ByteBuffersAsyncRequestBody bufferedAsyncRequestBody; + private List buffers = new ArrayList<>(); + private final AtomicBoolean subscribeCalled = new AtomicBoolean(false); + private final SimplePublisher delegate = new SimplePublisher<>(); + private final Consumer onNumBytesReceived; + private final Consumer onNumBytesConsumed; + private final Object buffersLock = new Object(); + + /** + * Creates a new RetryableSubAsyncRequestBody with the given configuration. + */ + public RetryableSubAsyncRequestBody(SubAsyncRequestBodyConfiguration configuration) { + this.configuration = Validate.paramNotNull(configuration, "configuration"); + this.partNumber = configuration.partNumber(); + this.contentLengthKnown = configuration.contentLengthKnown(); + this.sourceBodyName = configuration.sourceBodyName(); + this.onNumBytesReceived = configuration.onNumBytesReceived(); + this.onNumBytesConsumed = configuration.onNumBytesConsumed(); + } + + @Override + public Optional contentLength() { + return contentLengthKnown ? Optional.of(configuration.maxLength()) : Optional.of(bufferedLength); + } + + @Override + public void send(ByteBuffer data) { + log.trace(() -> String.format("Sending bytebuffer %s to part number %d", data, partNumber)); + long length = data.remaining(); + bufferedLength += length; + + onNumBytesReceived.accept(length); + delegate.send(data.asReadOnlyBuffer()).whenComplete((r, t) -> { + if (t != null) { + delegate.error(t); + } + }); + synchronized (buffersLock) { + buffers.add(data.asReadOnlyBuffer()); + } + } + + @Override + public void complete() { + log.debug(() -> "Received complete() for part number: " + partNumber); + // ByteBuffersAsyncRequestBody MUST be created before we complete the current + // request because retry may happen right after + synchronized (buffersLock) { + bufferedAsyncRequestBody = ByteBuffersAsyncRequestBody.of(buffers, bufferedLength); + } + delegate.complete().exceptionally(e -> { + delegate.error(e); + return null; + }); + } + + @Override + public long maxLength() { + return configuration.maxLength(); + } + + @Override + public long receivedBytesLength() { + return bufferedLength; + } + + @Override + public void subscribe(Subscriber s) { + log.debug(() -> "Subscribe for part number: " + partNumber); + if (subscribeCalled.compareAndSet(false, true)) { + delegate.subscribe(s); + } else { + log.debug(() -> "Resubscribe for part number " + partNumber); + if (bufferedAsyncRequestBody == null) { + s.onSubscribe(new NoopSubscription(s)); + s.onError(NonRetryableException.create( + "A retry was attempted, but data is not buffered successfully for retry for partNumber: " + partNumber)); + return; + } + bufferedAsyncRequestBody.subscribe(s); + } + } + + @Override + public void close() { + try { + log.debug(() -> "Closing current body " + partNumber); + onNumBytesConsumed.accept(bufferedLength); + if (bufferedAsyncRequestBody != null) { + synchronized (buffersLock) { + buffers.clear(); + buffers = null; + } + bufferedAsyncRequestBody.close(); + bufferedAsyncRequestBody = null; + } + } catch (Throwable e) { + log.warn(() -> String.format("Unexpected error thrown from cleaning up AsyncRequestBody for part number %d, " + + "resource may be leaked", partNumber)); + } + } + + @Override + public int partNumber() { + return partNumber; + } + + @Override + public String body() { + return sourceBodyName; + } + +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 12278cf84dca..d4b58b0285e3 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.core.internal.async; import java.nio.ByteBuffer; -import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -25,9 +24,8 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.exception.NonRetryableException; -import software.amazon.awssdk.core.internal.util.NoopSubscription; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.async.SimplePublisher; @@ -40,28 +38,33 @@ * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length. */ @SdkInternalApi -public class SplittingPublisher implements SdkPublisher { +public class SplittingPublisher implements SdkPublisher { private static final Logger log = Logger.loggerFor(SplittingPublisher.class); private final AsyncRequestBody upstreamPublisher; private final SplittingSubscriber splittingSubscriber; - private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); + private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); private final long chunkSizeInBytes; private final long bufferSizeInBytes; - - public SplittingPublisher(AsyncRequestBody asyncRequestBody, - AsyncRequestBodySplitConfiguration splitConfiguration) { - this.upstreamPublisher = Validate.paramNotNull(asyncRequestBody, "asyncRequestBody"); - Validate.notNull(splitConfiguration, "splitConfiguration"); - this.chunkSizeInBytes = splitConfiguration.chunkSizeInBytes() == null ? + private final boolean retryableSubAsyncRequestBodyEnabled; + private final AtomicBoolean currentBodySent = new AtomicBoolean(false); + private final String sourceBodyName; + + private SplittingPublisher(Builder builder) { + this.upstreamPublisher = Validate.paramNotNull(builder.asyncRequestBody, "asyncRequestBody"); + Validate.notNull(builder.splitConfiguration, "splitConfiguration"); + this.chunkSizeInBytes = builder.splitConfiguration.chunkSizeInBytes() == null ? AsyncRequestBodySplitConfiguration.defaultConfiguration().chunkSizeInBytes() : - splitConfiguration.chunkSizeInBytes(); + builder.splitConfiguration.chunkSizeInBytes(); - this.bufferSizeInBytes = splitConfiguration.bufferSizeInBytes() == null ? + this.bufferSizeInBytes = builder.splitConfiguration.bufferSizeInBytes() == null ? AsyncRequestBodySplitConfiguration.defaultConfiguration().bufferSizeInBytes() : - splitConfiguration.bufferSizeInBytes(); + builder.splitConfiguration.bufferSizeInBytes(); this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); + this.retryableSubAsyncRequestBodyEnabled = Validate.paramNotNull(builder.retryableSubAsyncRequestBodyEnabled, + "retryableSubAsyncRequestBodyEnabled"); + this.sourceBodyName = builder.asyncRequestBody.body(); if (!upstreamPublisher.contentLength().isPresent()) { Validate.isTrue(bufferSizeInBytes >= chunkSizeInBytes, "bufferSizeInBytes must be larger than or equal to " + @@ -69,8 +72,15 @@ public SplittingPublisher(AsyncRequestBody asyncRequestBody, } } + /** + * Returns a newly initialized builder object for a {@link SplittingPublisher} + */ + public static Builder builder() { + return new Builder(); + } + @Override - public void subscribe(Subscriber downstreamSubscriber) { + public void subscribe(Subscriber downstreamSubscriber) { downstreamPublisher.subscribe(downstreamSubscriber); upstreamPublisher.subscribe(splittingSubscriber); } @@ -78,8 +88,11 @@ public void subscribe(Subscriber downstreamSubscriber) private class SplittingSubscriber implements Subscriber { private Subscription upstreamSubscription; private final Long upstreamSize; - private final AtomicInteger chunkNumber = new AtomicInteger(0); - private volatile DownstreamBody currentBody; + /** + * 1 based index number for each part/chunk + */ + private final AtomicInteger partNumber = new AtomicInteger(1); + private volatile SubAsyncRequestBody currentBody; private final AtomicBoolean hasOpenUpstreamDemand = new AtomicBoolean(false); private final AtomicLong dataBuffered = new AtomicLong(0); @@ -98,13 +111,31 @@ public void onSubscribe(Subscription s) { this.upstreamSubscription = s; this.currentBody = initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize), - chunkNumber.get()); + partNumber.get()); // We need to request subscription *after* we set currentBody because onNext could be invoked right away. upstreamSubscription.request(1); } - private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) { - DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber); + private SubAsyncRequestBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) { + SubAsyncRequestBody body; + log.debug(() -> "initializing next downstream body " + partNumber); + + SubAsyncRequestBodyConfiguration config = SubAsyncRequestBodyConfiguration.builder() + .contentLengthKnown(contentLengthKnown) + .maxLength(chunkSize) + .partNumber(chunkNumber) + .onNumBytesReceived(data -> addDataBuffered(data)) + .onNumBytesConsumed(data -> addDataBuffered(-data)) + .sourceBodyName(sourceBodyName) + .build(); + + if (retryableSubAsyncRequestBodyEnabled) { + body = new RetryableSubAsyncRequestBody(config); + } else { + body = new NonRetryableSubAsyncRequestBody(config); + } + + currentBodySent.set(false); if (contentLengthKnown) { sendCurrentBody(body); } @@ -158,7 +189,7 @@ public void onNext(ByteBuffer byteBuffer) { private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { completeCurrentBody(); - int currentChunk = chunkNumber.incrementAndGet(); + int nextChunk = partNumber.incrementAndGet(); boolean shouldCreateNewDownstreamRequestBody; Long dataRemaining = totalDataRemaining(); @@ -170,18 +201,40 @@ private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { if (shouldCreateNewDownstreamRequestBody) { long chunkSize = calculateChunkSize(dataRemaining); - currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, nextChunk); } } private int amountRemainingInChunk() { - return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength); + return Math.toIntExact(currentBody.maxLength() - currentBody.receivedBytesLength()); } + private void completeCurrentBody() { - log.debug(() -> "completeCurrentBody for chunk " + chunkNumber.get()); + log.debug(() -> "completeCurrentBody for part " + currentBody.partNumber()); + long bufferedLength = currentBody.receivedBytesLength(); + // For unknown content length, we always create a new DownstreamBody once the current one is sent + // because we don't know if there is data + // left or not, so we need to check the length and only send the body if there is actually data + if (bufferedLength == 0) { + return; + } + + Long totalLength = currentBody.maxLength(); + if (upstreamSize != null && totalLength != bufferedLength) { + upstreamSubscription.cancel(); + downstreamPublisher.error(new IllegalStateException( + String.format("Content length of buffered data mismatches " + + "with the expected content length, buffered data content length: %d, " + + "expected length: %d", totalLength, + bufferedLength))); + return; + } currentBody.complete(); - if (upstreamSize == null) { + + // Current body could be completed in either onNext or onComplete, so we need to guard against sending the last body + // twice. + if (upstreamSize == null && currentBodySent.compareAndSet(false, true)) { sendCurrentBody(currentBody); } } @@ -189,18 +242,19 @@ private void completeCurrentBody() { @Override public void onComplete() { upstreamComplete = true; - log.trace(() -> "Received onComplete()"); + log.debug(() -> "Received onComplete()"); completeCurrentBody(); downstreamPublisher.complete(); } @Override public void onError(Throwable t) { - log.trace(() -> "Received onError()", t); + log.debug(() -> "Received onError()", t); downstreamPublisher.error(t); } - private void sendCurrentBody(AsyncRequestBody body) { + private void sendCurrentBody(SubAsyncRequestBody body) { + log.debug(() -> "sendCurrentBody for part " + body.partNumber()); downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); upstreamSubscription.cancel(); @@ -223,88 +277,68 @@ private void maybeRequestMoreUpstreamData() { hasOpenUpstreamDemand.compareAndSet(false, true)) { log.trace(() -> "Requesting more data, current data buffered: " + buffered); upstreamSubscription.request(1); + } else { + log.trace(() -> "Should not request more data, current data buffered: " + buffered); } } private boolean shouldRequestMoreData(long buffered) { - return buffered == 0 || buffered + byteBufferSizeHint <= bufferSizeInBytes; + return buffered <= 0 || buffered + byteBufferSizeHint <= bufferSizeInBytes; } private Long totalDataRemaining() { if (upstreamSize == null) { return null; } - return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); + return upstreamSize - ((partNumber.get() - 1) * chunkSizeInBytes); } - private final class DownstreamBody implements AsyncRequestBody { - - /** - * The maximum length of the content this AsyncRequestBody can hold. If the upstream content length is known, this is - * the same as totalLength - */ - private final long maxLength; - private final Long totalLength; - private final SimplePublisher delegate = new SimplePublisher<>(); - private final int chunkNumber; - private final AtomicBoolean subscribeCalled = new AtomicBoolean(false); - private volatile long transferredLength = 0; - - private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) { - this.totalLength = contentLengthKnown ? maxLength : null; - this.maxLength = maxLength; - this.chunkNumber = chunkNumber; + private void addDataBuffered(long length) { + log.trace(() -> "Adding data buffered " + length); + dataBuffered.addAndGet(length); + if (length < 0) { + maybeRequestMoreUpstreamData(); } + } + } - @Override - public Optional contentLength() { - return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength); - } + public static final class Builder { + private AsyncRequestBody asyncRequestBody; + private AsyncRequestBodySplitConfiguration splitConfiguration; + private Boolean retryableSubAsyncRequestBodyEnabled; - public void send(ByteBuffer data) { - log.trace(() -> String.format("Sending bytebuffer %s to chunk %d", data, chunkNumber)); - int length = data.remaining(); - transferredLength += length; - addDataBuffered(length); - delegate.send(data).whenComplete((r, t) -> { - addDataBuffered(-length); - if (t != null) { - error(t); - } - }); - } + private Builder() { + } - public void complete() { - log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength); - delegate.complete().whenComplete((r, t) -> { - if (t != null) { - error(t); - } - }); - } + /** + * Sets the AsyncRequestBody to be split. + */ + public Builder asyncRequestBody(AsyncRequestBody asyncRequestBody) { + this.asyncRequestBody = asyncRequestBody; + return this; + } - public void error(Throwable error) { - delegate.error(error); - } + /** + * Sets the split configuration. + */ + public Builder splitConfiguration(AsyncRequestBodySplitConfiguration splitConfiguration) { + this.splitConfiguration = splitConfiguration; + return this; + } - @Override - public void subscribe(Subscriber s) { - if (subscribeCalled.compareAndSet(false, true)) { - delegate.subscribe(s); - } else { - s.onSubscribe(new NoopSubscription(s)); - s.onError(NonRetryableException.create( - "A retry was attempted, but AsyncRequestBody.split does not " - + "support retries.")); - } - } + /** + * Sets whether to enable retryable sub async request bodies. + */ + public Builder retryableSubAsyncRequestBodyEnabled(Boolean retryableSubAsyncRequestBodyEnabled) { + this.retryableSubAsyncRequestBodyEnabled = retryableSubAsyncRequestBodyEnabled; + return this; + } - private void addDataBuffered(int length) { - dataBuffered.addAndGet(length); - if (length < 0) { - maybeRequestMoreUpstreamData(); - } - } + /** + * Builds a {@link SplittingPublisher} object based on the values held by this builder. + */ + public SplittingPublisher build() { + return new SplittingPublisher(this); } } -} +} \ No newline at end of file diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBody.java new file mode 100644 index 000000000000..6d3ec1b979a6 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBody.java @@ -0,0 +1,63 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; + +/** + * Represent a sub {@link AsyncRequestBody} that publishes a portion of the source {@link AsyncRequestBody} + */ +@SdkInternalApi +public interface SubAsyncRequestBody extends CloseableAsyncRequestBody { + + /** + * Send a byte buffer. + *

+ * This method must not be invoked concurrently. + */ + void send(ByteBuffer byteBuffer); + + /** + * Indicate that no more {@link #send(ByteBuffer)} )} calls will be made, + * and that stream of messages is completed successfully. + */ + void complete(); + + /** + * The maximum length of the content this AsyncRequestBody can hold. If the upstream content length is known, this should be + * the same as receivedBytesLength + */ + long maxLength(); + + /** + * The length of the bytes received + */ + long receivedBytesLength(); + + @Override + default void close() { + // no op + } + + /** + * The part number associated with this SubAsyncRequestBody + * @return + */ + int partNumber(); +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBodyConfiguration.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBodyConfiguration.java new file mode 100644 index 000000000000..cca1abe166d1 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SubAsyncRequestBodyConfiguration.java @@ -0,0 +1,140 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.util.function.Consumer; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.utils.Validate; + +/** + * Configuration class containing shared properties for SubAsyncRequestBody implementations. + */ +@SdkInternalApi +public final class SubAsyncRequestBodyConfiguration { + private final boolean contentLengthKnown; + private final long maxLength; + private final int partNumber; + private final Consumer onNumBytesReceived; + private final Consumer onNumBytesConsumed; + private final String sourceBodyName; + + private SubAsyncRequestBodyConfiguration(Builder builder) { + this.contentLengthKnown = Validate.paramNotNull(builder.contentLengthKnown, "contentLengthKnown"); + this.maxLength = Validate.paramNotNull(builder.maxLength, "maxLength"); + this.partNumber = Validate.paramNotNull(builder.partNumber, "partNumber"); + this.onNumBytesReceived = Validate.paramNotNull(builder.onNumBytesReceived, "onNumBytesReceived"); + this.onNumBytesConsumed = Validate.paramNotNull(builder.onNumBytesConsumed, "onNumBytesConsumed"); + this.sourceBodyName = Validate.paramNotNull(builder.sourceBodyName, "sourceBodyName"); + } + + /** + * Returns a newly initialized builder object for a {@link SubAsyncRequestBodyConfiguration} + */ + public static Builder builder() { + return new Builder(); + } + + public boolean contentLengthKnown() { + return contentLengthKnown; + } + + public long maxLength() { + return maxLength; + } + + public int partNumber() { + return partNumber; + } + + public Consumer onNumBytesReceived() { + return onNumBytesReceived; + } + + public Consumer onNumBytesConsumed() { + return onNumBytesConsumed; + } + + public String sourceBodyName() { + return sourceBodyName; + } + + public static final class Builder { + private Boolean contentLengthKnown; + private Long maxLength; + private Integer partNumber; + private Consumer onNumBytesReceived; + private Consumer onNumBytesConsumed; + private String sourceBodyName; + + private Builder() { + } + + /** + * Sets whether the content length is known. + */ + public Builder contentLengthKnown(Boolean contentLengthKnown) { + this.contentLengthKnown = contentLengthKnown; + return this; + } + + /** + * Sets the maximum length of the content this AsyncRequestBody can hold. + */ + public Builder maxLength(Long maxLength) { + this.maxLength = maxLength; + return this; + } + + /** + * Sets the part number for this request body. + */ + public Builder partNumber(Integer partNumber) { + this.partNumber = partNumber; + return this; + } + + /** + * Sets the callback to be invoked when bytes are received. + */ + public Builder onNumBytesReceived(Consumer onNumBytesReceived) { + this.onNumBytesReceived = onNumBytesReceived; + return this; + } + + /** + * Sets the callback to be invoked when bytes are consumed. + */ + public Builder onNumBytesConsumed(Consumer onNumBytesConsumed) { + this.onNumBytesConsumed = onNumBytesConsumed; + return this; + } + + /** + * Sets the source body name for identification. + */ + public Builder sourceBodyName(String sourceBodyName) { + this.sourceBodyName = sourceBodyName; + return this; + } + + /** + * Builds a {@link SubAsyncRequestBodyConfiguration} object based on the values held by this builder. + */ + public SubAsyncRequestBodyConfiguration build() { + return new SubAsyncRequestBodyConfiguration(this); + } + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java similarity index 97% rename from core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java rename to core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java index 8b8f78f2b5e9..e932da3bfa1c 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java @@ -23,7 +23,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -public class AsyncRequestBodyConfigurationTest { +public class AsyncRequestBodySplitConfigurationTest { @Test void equalsHashCode() { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java index cdd87822e3d4..f0c58c37e9a4 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java @@ -19,10 +19,12 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import io.reactivex.Flowable; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; @@ -31,13 +33,20 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.List; +import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.util.Lists; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; @@ -49,16 +58,33 @@ public class AsyncRequestBodyTest { private static final String testString = "Hello!"; - private static final Path path; - - static { - FileSystem fs = Jimfs.newFileSystem(Configuration.unix()); + private static Path path; + private static final int CONTENT_SIZE = 1024; + private static final byte[] CONTENT = + RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset()); + private static File fileForSplit; + private static FileSystem fs; + + @BeforeAll + public static void setup() throws IOException { + fs = Jimfs.newFileSystem(Configuration.unix()); path = fs.getPath("./test"); - try { - Files.write(path, testString.getBytes()); - } catch (IOException e) { - e.printStackTrace(); - } + Files.write(path, testString.getBytes()); + + fileForSplit = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); + Files.write(fileForSplit.toPath(), CONTENT); + } + + @AfterAll + public static void teardown() throws IOException { + fileForSplit.delete(); + fs.close(); + } + + public static Stream asyncRequestBodies() { + return Stream.of(Arguments.of(AsyncRequestBody.fromBytes(CONTENT)), + Arguments.of(AsyncRequestBody.fromFile(b -> b.chunkSizeInBytes(50) + .path(fileForSplit.toPath())))); } @ParameterizedTest @@ -300,6 +326,34 @@ void rewindingByteBufferBuildersReadTheInputBufferFromTheBeginning( assertEquals(bb, publishedBuffer.get()); } + @ParameterizedTest + @MethodSource("asyncRequestBodies") + void legacySplit_shouldWork(AsyncRequestBody delegate) throws Exception { + long chunkSize = 20l; + AsyncRequestBody asyncRequestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return delegate.contentLength(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + }; + + AsyncRequestBodySplitConfiguration configuration = AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes(chunkSize) + .bufferSizeInBytes(chunkSize) + .build(); + + SdkPublisher split = asyncRequestBody.split(configuration); + verifyIndividualAsyncRequestBody(split, fileForSplit.toPath(), (int) chunkSize); + } + + + + private static Function[] rewindingByteBufferBodyBuilders() { Function fromByteBuffer = AsyncRequestBody::fromByteBuffer; Function fromByteBufferUnsafe = AsyncRequestBody::fromByteBufferUnsafe; @@ -356,4 +410,13 @@ void publisherConstructorHasCorrectContentType() { AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(bodyPublisher); assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType()); } + + @Test + void splitV2_nullConfig_shouldThrowException() { + AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world"); + AsyncRequestBodySplitConfiguration config = null; + assertThatThrownBy(() -> requestBody.splitCloseable(config)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("splitConfig"); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java index 4c5d0748d16d..1edea1a58b1b 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java @@ -77,7 +77,9 @@ public void split_differentChunkSize_shouldSplitCorrectly(int chunkSize) throws ScheduledFuture scheduledFuture = executor.scheduleWithFixedDelay(verifyConcurrentRequests(helper, maxConcurrency), 1, 50, TimeUnit.MICROSECONDS); - verifyIndividualAsyncRequestBody(helper.split(), testFile, chunkSize); + verifyIndividualAsyncRequestBody(helper.split(), + testFile, + chunkSize); scheduledFuture.cancel(true); int expectedMaxConcurrency = (int) (bufferSize / chunkSizeInBytes); assertThat(maxConcurrency.get()).isLessThanOrEqualTo(expectedMaxConcurrency); diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBodyTest.java new file mode 100644 index 000000000000..583aa077f21c --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/NonRetryableSubAsyncRequestBodyTest.java @@ -0,0 +1,152 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.exception.NonRetryableException; + +class NonRetryableSubAsyncRequestBodyTest { + + private SubAsyncRequestBodyConfiguration configuration; + private Consumer onNumBytesReceived; + private Consumer onNumBytesConsumed; + private NonRetryableSubAsyncRequestBody requestBody; + + @BeforeEach + void setUp() { + onNumBytesReceived = mock(Consumer.class); + onNumBytesConsumed = mock(Consumer.class); + + configuration = SubAsyncRequestBodyConfiguration.builder() + .contentLengthKnown(true) + .maxLength(1024L) + .partNumber(1) + .onNumBytesReceived(onNumBytesReceived) + .onNumBytesConsumed(onNumBytesConsumed) + .sourceBodyName("test-body") + .build(); + + requestBody = new NonRetryableSubAsyncRequestBody(configuration); + } + + @Test + void getters_shouldReturnConfigurationValues() { + assertThat(requestBody.maxLength()).isEqualTo(1024L); + assertThat(requestBody.partNumber()).isEqualTo(1); + assertThat(requestBody.body()).isEqualTo("test-body"); + assertThat(requestBody.contentLength()).isEqualTo(Optional.of(1024L)); + assertThat(requestBody.receivedBytesLength()).isEqualTo(0L); + } + + @Test + void constructor_withNullConfiguration_shouldThrowException() { + assertThatThrownBy(() -> new NonRetryableSubAsyncRequestBody(null)) + .isInstanceOf(NullPointerException.class); + } + + @Test + void contentLength_whenContentLengthUnknown_shouldReturnBufferedLength() { + SubAsyncRequestBodyConfiguration unknownLengthConfig = SubAsyncRequestBodyConfiguration.builder() + .contentLengthKnown(false) + .maxLength(1024L) + .partNumber(1) + .onNumBytesReceived(onNumBytesReceived) + .onNumBytesConsumed(onNumBytesConsumed) + .sourceBodyName("test-body") + .build(); + + NonRetryableSubAsyncRequestBody unknownLengthBody = new NonRetryableSubAsyncRequestBody(unknownLengthConfig); + + assertThat(unknownLengthBody.contentLength()).isEqualTo(Optional.of(0L)); + + // Send some data + ByteBuffer data = ByteBuffer.wrap("test".getBytes()); + unknownLengthBody.send(data); + + assertThat(unknownLengthBody.contentLength()).isEqualTo(Optional.of(4L)); + } + + @Test + void subscribe_shouldReceiveAllData() { + byte[] part1 = RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8); + byte[] part2 = RandomStringUtils.randomAscii(512).getBytes(StandardCharsets.UTF_8); + requestBody.send(ByteBuffer.wrap(part1)); + requestBody.send(ByteBuffer.wrap(part2)); + requestBody.complete(); + List receivedBuffers = new ArrayList<>(); + Flowable.fromPublisher(requestBody).forEach(buffer -> receivedBuffers.add(buffer)); + + verify(onNumBytesReceived).accept(1024L); + verify(onNumBytesConsumed).accept(1024L); + verify(onNumBytesReceived).accept(512L); + verify(onNumBytesConsumed).accept(512L); + assertThat(requestBody.receivedBytesLength()).isEqualTo(1536L); + assertThat(receivedBuffers).containsExactly(ByteBuffer.wrap(part1), ByteBuffer.wrap(part2)); + } + + @Test + void subscribe_secondTime_shouldSendError() { + Subscriber subscriber1 = mock(Subscriber.class); + Subscriber subscriber2 = mock(Subscriber.class); + + // First subscription + requestBody.subscribe(subscriber1); + + // Second subscription should fail + requestBody.subscribe(subscriber2); + + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + verify(subscriber2).onSubscribe(subscriptionCaptor.capture()); + + ArgumentCaptor errorCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(subscriber2).onError(errorCaptor.capture()); + + Throwable error = errorCaptor.getValue(); + assertThat(error).isInstanceOf(NonRetryableException.class); + assertThat(error.getMessage()).contains("This could happen due to a retry attempt"); + } + + @Test + void receivedBytesLength_shouldTrackSentData() { + assertThat(requestBody.receivedBytesLength()).isEqualTo(0L); + + ByteBuffer data1 = ByteBuffer.wrap("hello".getBytes()); + requestBody.send(data1); + assertThat(requestBody.receivedBytesLength()).isEqualTo(5L); + + ByteBuffer data2 = ByteBuffer.wrap(" world".getBytes()); + requestBody.send(data2); + assertThat(requestBody.receivedBytesLength()).isEqualTo(11L); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBodyTest.java new file mode 100644 index 000000000000..969bdb99cb4a --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/RetryableSubAsyncRequestBodyTest.java @@ -0,0 +1,203 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.exception.NonRetryableException; + +class RetryableSubAsyncRequestBodyTest { + + private SubAsyncRequestBodyConfiguration configuration; + private Consumer onNumBytesReceived; + private Consumer onNumBytesConsumed; + private RetryableSubAsyncRequestBody requestBody; + + @BeforeEach + void setUp() { + onNumBytesReceived = mock(Consumer.class); + onNumBytesConsumed = mock(Consumer.class); + + configuration = SubAsyncRequestBodyConfiguration.builder() + .contentLengthKnown(true) + .maxLength(1024L) + .partNumber(1) + .onNumBytesReceived(onNumBytesReceived) + .onNumBytesConsumed(onNumBytesConsumed) + .sourceBodyName("test-body") + .build(); + + requestBody = new RetryableSubAsyncRequestBody(configuration); + } + + @Test + void getters_shouldReturnConfigurationValues() { + assertThat(requestBody.maxLength()).isEqualTo(1024L); + assertThat(requestBody.partNumber()).isEqualTo(1); + assertThat(requestBody.body()).isEqualTo("test-body"); + assertThat(requestBody.contentLength()).isEqualTo(Optional.of(1024L)); + assertThat(requestBody.receivedBytesLength()).isEqualTo(0L); + } + + @Test + void constructor_withNullConfiguration_shouldThrowException() { + assertThatThrownBy(() -> new RetryableSubAsyncRequestBody(null)) + .isInstanceOf(NullPointerException.class); + } + + @Test + void contentLength_whenContentLengthUnknown_shouldReturnBufferedLength() { + SubAsyncRequestBodyConfiguration unknownLengthConfig = SubAsyncRequestBodyConfiguration.builder() + .contentLengthKnown(false) + .maxLength(1024L) + .partNumber(1) + .onNumBytesReceived(onNumBytesReceived) + .onNumBytesConsumed(onNumBytesConsumed) + .sourceBodyName("test-body") + .build(); + + RetryableSubAsyncRequestBody unknownLengthBody = new RetryableSubAsyncRequestBody(unknownLengthConfig); + + assertThat(unknownLengthBody.contentLength()).isEqualTo(Optional.of(0L)); + + // Send some data + ByteBuffer data = ByteBuffer.wrap("test".getBytes()); + unknownLengthBody.send(data); + + assertThat(unknownLengthBody.contentLength()).isEqualTo(Optional.of(4L)); + } + + @Test + void subscribe_shouldReceiveAllData() { + byte[] part1 = RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8); + byte[] part2 = RandomStringUtils.randomAscii(512).getBytes(StandardCharsets.UTF_8); + requestBody.send(ByteBuffer.wrap(part1)); + requestBody.send(ByteBuffer.wrap(part2)); + requestBody.complete(); + List receivedBuffers = new ArrayList<>(); + Flowable.fromPublisher(requestBody).forEach(buffer -> receivedBuffers.add(buffer)); + + verify(onNumBytesReceived).accept(1024L); + verify(onNumBytesReceived).accept(512L); + assertThat(requestBody.receivedBytesLength()).isEqualTo(1536L); + assertThat(receivedBuffers).containsExactly(ByteBuffer.wrap(part1), ByteBuffer.wrap(part2)); + } + + @Test + void subscribe_secondTime_shouldUseBufferedBody() { + byte[] part1 = RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8); + byte[] part2 = RandomStringUtils.randomAscii(512).getBytes(StandardCharsets.UTF_8); + requestBody.send(ByteBuffer.wrap(part1)); + requestBody.send(ByteBuffer.wrap(part2)); + requestBody.complete(); + + List buffer1 = new ArrayList<>(); + Flowable.fromPublisher(requestBody).forEach(buffer -> buffer1.add(buffer)); + + List buffer2 = new ArrayList<>(); + Flowable.fromPublisher(requestBody).forEach(buffer -> buffer2.add(buffer)); + + assertThat(buffer1).containsExactly(ByteBuffer.wrap(part1), ByteBuffer.wrap(part2)); + assertThat(buffer2).containsExactly(ByteBuffer.wrap(part1), ByteBuffer.wrap(part2)); + } + + @Test + void subscribe_retryWithoutFirstSubscriberDone_shouldSendError() { + Subscriber subscriber1 = mock(Subscriber.class); + Subscriber subscriber2 = mock(Subscriber.class); + + // First subscription + requestBody.subscribe(subscriber1); + // Second subscription without completing first (no buffered data) + requestBody.subscribe(subscriber2); + + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + verify(subscriber2).onSubscribe(subscriptionCaptor.capture()); + + ArgumentCaptor errorCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(subscriber2).onError(errorCaptor.capture()); + + Throwable error = errorCaptor.getValue(); + assertThat(error).isInstanceOf(NonRetryableException.class); + assertThat(error.getMessage()).contains("data is not buffered successfully for retry"); + } + + @Test + void subscribe_resubscribeAfterClose_shouldSendError() { + byte[] data = RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8); + requestBody.send(ByteBuffer.wrap(data)); + requestBody.complete(); + + Flowable.fromPublisher(requestBody).forEach(buffer -> {}); + + requestBody.close(); + Subscriber secondSubscriber = mock(Subscriber.class); + requestBody.subscribe(secondSubscriber); + + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + verify(secondSubscriber).onSubscribe(subscriptionCaptor.capture()); + + ArgumentCaptor errorCaptor = ArgumentCaptor.forClass(Throwable.class); + verify(secondSubscriber).onError(errorCaptor.capture()); + + Throwable error = errorCaptor.getValue(); + assertThat(error).isInstanceOf(NonRetryableException.class); + assertThat(error.getMessage()).contains("data is not buffered successfully for retry"); + } + + @Test + void close_shouldInvokeOnNumBytesConsumed() { + ByteBuffer data = ByteBuffer.wrap("test data".getBytes()); + requestBody.send(data); + + requestBody.close(); + + verify(onNumBytesConsumed).accept(9L); + } + + @Test + void receivedBytesLength_shouldTrackSentData() { + assertThat(requestBody.receivedBytesLength()).isEqualTo(0L); + + ByteBuffer data1 = ByteBuffer.wrap("hello".getBytes()); + requestBody.send(data1); + assertThat(requestBody.receivedBytesLength()).isEqualTo(5L); + + ByteBuffer data2 = ByteBuffer.wrap(" world".getBytes()); + requestBody.send(data2); + assertThat(requestBody.receivedBytesLength()).isEqualTo(11L); + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 6f116ca2667c..87ac0b4726f4 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.core.internal.async; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody; import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; @@ -24,17 +23,20 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.file.Files; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.AfterAll; @@ -42,12 +44,12 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.Pair; public class SplittingPublisherTest { private static final int CHUNK_SIZE = 5; @@ -75,10 +77,15 @@ public static void afterAll() throws Exception { public void split_contentUnknownMaxMemorySmallerThanChunkSize_shouldThrowException() { AsyncRequestBody body = AsyncRequestBody.fromPublisher(s -> { }); - assertThatThrownBy(() -> new SplittingPublisher(body, AsyncRequestBodySplitConfiguration.builder() - .chunkSizeInBytes(10L) - .bufferSizeInBytes(5L) - .build())) + AsyncRequestBodySplitConfiguration configuration = AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes(10L) + .bufferSizeInBytes(5L) + .build(); + assertThatThrownBy(() -> SplittingPublisher.builder() + .asyncRequestBody(body) + .splitConfiguration(configuration) + .retryableSubAsyncRequestBodyEnabled(false) + .build()) .hasMessageContaining("must be larger than or equal"); } @@ -91,16 +98,24 @@ void differentChunkSize_shouldSplitAsyncRequestBodyCorrectly(int chunkSize) thro .chunkSizeInBytes(chunkSize) .build(); verifySplitContent(fileAsyncRequestBody, chunkSize); + fileAsyncRequestBody = FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .chunkSizeInBytes(chunkSize) + .build(); + verifyRetryableSplitContent(fileAsyncRequestBody, chunkSize); } @ParameterizedTest @ValueSource(ints = {CHUNK_SIZE, CHUNK_SIZE * 2 - 1, CHUNK_SIZE * 2}) void differentChunkSize_byteArrayShouldSplitAsyncRequestBodyCorrectly(int chunkSize) throws Exception { verifySplitContent(AsyncRequestBody.fromBytes(CONTENT), chunkSize); + verifyRetryableSplitContent(AsyncRequestBody.fromBytes(CONTENT), chunkSize); } - @Test - void contentLengthNotPresent_shouldHandle() throws Exception { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void contentLengthNotPresent_shouldHandle(boolean enableRetryableSubAsyncRequestBody) throws Exception { CompletableFuture future = new CompletableFuture<>(); TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody() { @Override @@ -108,10 +123,14 @@ public Optional contentLength() { return Optional.empty(); } }; - SplittingPublisher splittingPublisher = new SplittingPublisher(asyncRequestBody, AsyncRequestBodySplitConfiguration.builder() - .chunkSizeInBytes((long) CHUNK_SIZE) - .bufferSizeInBytes(10L) - .build()); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .splitConfiguration(AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes((long) CHUNK_SIZE) + .bufferSizeInBytes(10L) + .build()) + .retryableSubAsyncRequestBodyEnabled(enableRetryableSubAsyncRequestBody) + .build(); List> futures = new ArrayList<>(); @@ -147,29 +166,93 @@ public Optional contentLength() { } - @Test - void downStreamFailed_shouldPropagateCancellation() { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void downStreamFailed_shouldPropagateCancellation(boolean enableRetryableSubAsyncRequestBody) throws Exception { CompletableFuture future = new CompletableFuture<>(); TestAsyncRequestBody asyncRequestBody = new TestAsyncRequestBody(); - SplittingPublisher splittingPublisher = new SplittingPublisher(asyncRequestBody, AsyncRequestBodySplitConfiguration.builder() - .chunkSizeInBytes((long) CHUNK_SIZE) - .bufferSizeInBytes(10L) - .build()); - + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .splitConfiguration(AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes((long) CHUNK_SIZE) + .bufferSizeInBytes(10L) + .build()) + .retryableSubAsyncRequestBodyEnabled(enableRetryableSubAsyncRequestBody) + .build(); assertThatThrownBy(() -> splittingPublisher.subscribe(requestBody -> { throw new RuntimeException("foobar"); }).get(5, TimeUnit.SECONDS)).hasMessageContaining("foobar"); assertThat(asyncRequestBody.cancelled).isTrue(); } + @Test + void retryableSubAsyncRequestBodyEnabled_shouldBeAbleToResubscribe() throws ExecutionException, InterruptedException, TimeoutException { + int chunkSize = 5; + AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder() + .path(testFile.toPath()) + .chunkSizeInBytes(chunkSize) + .build(); + + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .splitConfiguration(AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes((long) chunkSize) + .bufferSizeInBytes((long) chunkSize * 4) + .build()) + .retryableSubAsyncRequestBodyEnabled(true) + .build(); + + + + Map, CompletableFuture>> futures = new HashMap<>(); + AtomicInteger index = new AtomicInteger(); + splittingPublisher.subscribe(requestBody -> { + int i = index.getAndIncrement(); + CompletableFuture future = new CompletableFuture<>(); + BaosSubscriber subscriber = new BaosSubscriber(future); + requestBody.subscribe(subscriber); + + future.whenComplete((r, t) -> { + CompletableFuture future2 = new CompletableFuture<>(); + BaosSubscriber anotherSubscriber = new BaosSubscriber(future2); + requestBody.subscribe(anotherSubscriber); + futures.put(i, Pair.of(future, future2)); + + future2.whenComplete((res, throwable) -> { + requestBody.close(); + }); + }); + }).get(5, TimeUnit.SECONDS); + + for (int i = 0; i < futures.size(); i++) { + assertThat(futures.get(i).left().join()).containsExactly( futures.get(i).right().join()); + } + } + private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { - SplittingPublisher splittingPublisher = new SplittingPublisher(asyncRequestBody, - AsyncRequestBodySplitConfiguration.builder() - .chunkSizeInBytes((long) chunkSize) - .bufferSizeInBytes((long) chunkSize * 4) - .build()); + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .splitConfiguration(AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes((long) chunkSize) + .bufferSizeInBytes((long) chunkSize * 4) + .build()) + .retryableSubAsyncRequestBodyEnabled(false) + .build(); + + verifyIndividualAsyncRequestBody(splittingPublisher.map(m -> m), testFile.toPath(), chunkSize); + } - verifyIndividualAsyncRequestBody(splittingPublisher, testFile.toPath(), chunkSize); + private static void verifyRetryableSplitContent(AsyncRequestBody asyncRequestBody, int chunkSize) throws Exception { + SplittingPublisher splittingPublisher = SplittingPublisher.builder() + .asyncRequestBody(asyncRequestBody) + .splitConfiguration(AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes((long) chunkSize) + .bufferSizeInBytes((long) chunkSize * 4) + .build()) + .retryableSubAsyncRequestBodyEnabled(false) + .build(); + + verifyIndividualAsyncRequestBody(splittingPublisher.map(m -> m), testFile.toPath(), chunkSize); } private static class TestAsyncRequestBody implements AsyncRequestBody { @@ -204,30 +287,6 @@ public void cancel() { } } - private static final class OnlyRequestOnceSubscriber implements Subscriber { - private List asyncRequestBodies = new ArrayList<>(); - - @Override - public void onSubscribe(Subscription s) { - s.request(1); - } - - @Override - public void onNext(AsyncRequestBody requestBody) { - asyncRequestBodies.add(requestBody); - } - - @Override - public void onError(Throwable t) { - - } - - @Override - public void onComplete() { - - } - } - private static final class BaosSubscriber implements Subscriber { private final CompletableFuture resultFuture; diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java index 04da97adbf42..095877d9f2b5 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java @@ -15,23 +15,16 @@ package software.amazon.awssdk.core.internal.async; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -import java.io.File; import java.io.FileInputStream; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import org.assertj.core.api.Assertions; -import org.reactivestreams.Publisher; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; -import software.amazon.awssdk.core.internal.async.SplittingPublisherTest; public final class SplittingPublisherTestUtils { @@ -45,6 +38,11 @@ public static void verifyIndividualAsyncRequestBody(SdkPublisher { + if (requestBody instanceof CloseableAsyncRequestBody) { + ((CloseableAsyncRequestBody) requestBody).close(); + } + }); futures.add(baosFuture); }).get(5, TimeUnit.SECONDS); diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java index e1f68ce2e234..de20ce7e6331 100644 --- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientPutObjectIntegrationTest.java @@ -43,6 +43,7 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; import java.util.zip.CRC32; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -54,6 +55,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Subscriber; import software.amazon.awssdk.core.ClientType; import software.amazon.awssdk.core.ResponseBytes; @@ -61,6 +65,7 @@ import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; +import software.amazon.awssdk.core.async.BufferedSplittableAsyncRequestBody; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; @@ -75,6 +80,7 @@ import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.utils.ChecksumUtils; @@ -93,7 +99,9 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor(); private static File testFile; private static S3AsyncClient mpuS3Client; - private static ExecutorService executorService = Executors.newFixedThreadPool(2); + private static ExecutorService executorService = Executors.newFixedThreadPool(5); + private static byte[] bytes; + private static byte[] expectedChecksum; @BeforeAll public static void setup() throws Exception { @@ -101,6 +109,8 @@ public static void setup() throws Exception { createBucket(TEST_BUCKET); testFile = new RandomTempFile(OBJ_SIZE); + bytes = Files.readAllBytes(testFile.toPath()); + expectedChecksum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); mpuS3Client = S3AsyncClient .builder() .region(DEFAULT_REGION) @@ -119,6 +129,18 @@ public static void teardown() throws Exception { executorService.shutdown(); } + public static Stream asyncRequestBodies() { + return Stream.of(Arguments.of("file", AsyncRequestBody.fromFile(testFile)), + Arguments.of("bytes", AsyncRequestBody.fromBytes(bytes)), + Arguments.of("inputStream_knownLength", + AsyncRequestBody.fromInputStream(new ByteArrayInputStream(bytes), (long) bytes.length, + executorService)), + Arguments.of("inputStream_unknownLength", + AsyncRequestBody.fromInputStream(new ByteArrayInputStream(bytes), null, + executorService)) + ); + } + @BeforeEach public void reset() { CAPTURING_INTERCEPTOR.reset(); @@ -144,76 +166,28 @@ public void upload_blockingInputStream_shouldSucceed() throws IOException { assertEquals(expectedMd5, actualMd5); } - @Test - void putObject_fileRequestBody_objectSentCorrectly() throws Exception { - AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath()); + @ParameterizedTest + @MethodSource("asyncRequestBodies") + void putObject_variousRequestBody_objectSentCorrectly(String description, AsyncRequestBody body) throws Exception { mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); - assertThat(CAPTURING_INTERCEPTOR.createMpuChecksumAlgorithm).isEqualTo("CRC32"); - assertThat(CAPTURING_INTERCEPTOR.uploadPartChecksumAlgorithm).isEqualTo("CRC32"); - ResponseInputStream objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), ResponseTransformer.toInputStream()); - assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); - byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); - assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); - } - - @Test - void putObject_inputStreamAsyncRequestBody_objectSentCorrectly() throws Exception { - AsyncRequestBody body = AsyncRequestBody.fromInputStream( - new FileInputStream(testFile), - Long.valueOf(OBJ_SIZE), - executorService); - mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET) - .key(TEST_KEY) - .contentLength(Long.valueOf(OBJ_SIZE)), body).join(); - assertThat(CAPTURING_INTERCEPTOR.createMpuChecksumAlgorithm).isEqualTo("CRC32"); assertThat(CAPTURING_INTERCEPTOR.uploadPartChecksumAlgorithm).isEqualTo("CRC32"); - ResponseInputStream objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); - assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); - byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); - assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); - } - - @Test - void putObject_byteAsyncRequestBody_objectSentCorrectly() throws Exception { - byte[] bytes = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset()); - AsyncRequestBody body = AsyncRequestBody.fromBytes(bytes); - mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); - - assertThat(CAPTURING_INTERCEPTOR.createMpuChecksumAlgorithm).isEqualTo("CRC32"); - assertThat(CAPTURING_INTERCEPTOR.uploadPartChecksumAlgorithm).isEqualTo("CRC32"); - - ResponseInputStream objContent = s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), - ResponseTransformer.toInputStream()); - - assertThat(objContent.response().contentLength()).isEqualTo(OBJ_SIZE); - byte[] expectedSum = ChecksumUtils.computeCheckSum(new ByteArrayInputStream(bytes)); - assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedChecksum); } - @Test - void putObject_unknownContentLength_objectSentCorrectly() throws Exception { - AsyncRequestBody body = FileAsyncRequestBody.builder() - .path(testFile.toPath()) - .build(); - mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), new AsyncRequestBody() { - @Override - public Optional contentLength() { - return Optional.empty(); - } - @Override - public void subscribe(Subscriber s) { - body.subscribe(s); - } - }).get(30, SECONDS); + @ParameterizedTest + @MethodSource("asyncRequestBodies") + void putObject_wrapWithBufferedSplittableAsyncRequestBody_objectSentCorrectly(String description, + AsyncRequestBody asyncRequestBody) throws Exception { + AsyncRequestBody body = BufferedSplittableAsyncRequestBody.create(asyncRequestBody); + mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join(); assertThat(CAPTURING_INTERCEPTOR.createMpuChecksumAlgorithm).isEqualTo("CRC32"); assertThat(CAPTURING_INTERCEPTOR.uploadPartChecksumAlgorithm).isEqualTo("CRC32"); @@ -222,8 +196,7 @@ public void subscribe(Subscriber s) { ResponseTransformer.toInputStream()); assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); - byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); - assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedChecksum); } @Test @@ -251,8 +224,7 @@ void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception { ResponseTransformer.toInputStream()); assertThat(objContent.response().contentLength()).isEqualTo(testFile.length()); - byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath())); - assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum); + assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedChecksum); } @Test diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java index 93bc0dfeb6f8..d86005d85bc4 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java @@ -34,6 +34,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; @@ -47,7 +48,7 @@ import software.amazon.awssdk.utils.Pair; @SdkInternalApi -public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { +public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { private static final Logger log = Logger.loggerFor(KnownContentLengthAsyncRequestBodySubscriber.class); @@ -144,16 +145,21 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(AsyncRequestBody asyncRequestBody) { + public void onNext(CloseableAsyncRequestBody asyncRequestBody) { if (isPaused || isDone) { return; } int currentPartNum = partNumber.getAndIncrement(); + + log.debug(() -> String.format("Received asyncRequestBody for part number %d with length %s", currentPartNum, + asyncRequestBody.contentLength())); + if (existingParts.containsKey(currentPartNum)) { asyncRequestBody.subscribe(new CancelledSubscriber<>()); - subscription.request(1); asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext); + asyncRequestBody.close(); + subscription.request(1); return; } @@ -178,10 +184,12 @@ public void onNext(AsyncRequestBody asyncRequestBody) { multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, Pair.of(uploadRequest, asyncRequestBody), progressListener) .whenComplete((r, t) -> { + asyncRequestBody.close(); if (t != null) { if (shouldFailRequest()) { multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + subscription.cancel(); } } else { completeMultipartUploadIfFinished(asyncRequestBodyInFlight.decrementAndGet()); @@ -206,7 +214,7 @@ private Optional validatePart(AsyncRequestBody asyncRequestB } if (currentPartSize != partSize) { - return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize)); + return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize, currentPartNum)); } return Optional.empty(); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index d25d5b6fa7fa..d7c988c16e55 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -162,11 +162,12 @@ static SdkClientException contentLengthMissingForPart(int currentPartNum) { return SdkClientException.create("Content length is missing on the AsyncRequestBody for part number " + currentPartNum); } - static SdkClientException contentLengthMismatchForPart(long expected, long actual) { + static SdkClientException contentLengthMismatchForPart(long expected, long actual, int partNum) { return SdkClientException.create(String.format("Content length must not be greater than " - + "part size. Expected: %d, Actual: %d", + + "part size. Expected: %d, Actual: %d, partNum: %d", expected, - actual)); + actual, + partNum)); } static SdkClientException partNumMismatch(int expectedNumParts, int actualNumParts) { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 04690677c92b..0fdeb1674798 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -186,8 +186,8 @@ private void splitAndSubscribe(MpuRequestContext mpuRequestContext, CompletableF attachSubscriberToObservable(subscriber, mpuRequestContext.request().left()); mpuRequestContext.request().right() - .split(b -> b.chunkSizeInBytes(mpuRequestContext.partSize()) - .bufferSizeInBytes(maxMemoryUsageInBytes)) + .splitCloseable(b -> b.chunkSizeInBytes(mpuRequestContext.partSize()) + .bufferSizeInBytes(maxMemoryUsageInBytes)) .subscribe(subscriber); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index 520625ad90b0..cab480a540cb 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -33,6 +33,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.core.exception.SdkClientException; @@ -81,9 +82,9 @@ public CompletableFuture uploadObject(PutObjectRequest putObj AsyncRequestBody asyncRequestBody) { CompletableFuture returnFuture = new CompletableFuture<>(); - SdkPublisher splitAsyncRequestBodyResponse = - asyncRequestBody.split(b -> b.chunkSizeInBytes(partSizeInBytes) - .bufferSizeInBytes(maxMemoryUsageInBytes)); + SdkPublisher splitAsyncRequestBodyResponse = + asyncRequestBody.splitCloseable(b -> b.chunkSizeInBytes(partSizeInBytes) + .bufferSizeInBytes(maxMemoryUsageInBytes)); splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, putObjectRequest, @@ -91,7 +92,7 @@ public CompletableFuture uploadObject(PutObjectRequest putObj return returnFuture; } - private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { + private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { /** * Indicates whether this is the first async request body or not. */ @@ -127,7 +128,7 @@ private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscrib private final CompletableFuture returnFuture; private final PublisherListener progressListener; private Subscription subscription; - private AsyncRequestBody firstRequestBody; + private CloseableAsyncRequestBody firstRequestBody; private String uploadId; private volatile boolean isDone; @@ -161,7 +162,7 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(AsyncRequestBody asyncRequestBody) { + public void onNext(CloseableAsyncRequestBody asyncRequestBody) { if (isDone) { return; } @@ -224,14 +225,14 @@ private Optional validatePart(AsyncRequestBody asyncRequestB Long contentLengthCurrentPart = contentLength.get(); if (contentLengthCurrentPart > partSizeInBytes) { - return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart)); + return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart, currentPartNum)); } return Optional.empty(); } private void sendUploadPartRequest(String uploadId, - AsyncRequestBody asyncRequestBody, + CloseableAsyncRequestBody asyncRequestBody, int currentPartNum) { Long contentLengthCurrentPart = asyncRequestBody.contentLength().get(); this.contentLength.getAndAdd(contentLengthCurrentPart); @@ -240,6 +241,7 @@ private void sendUploadPartRequest(String uploadId, .sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, uploadPart(asyncRequestBody, currentPartNum), progressListener) .whenComplete((r, t) -> { + asyncRequestBody.close(); if (t != null) { if (failureActionInitiated.compareAndSet(false, true)) { multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); @@ -305,4 +307,4 @@ private void completeMultipartUploadIfFinish(int requestsInFlight) { } } } -} +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java index 4faf9d4a04b0..c18f088f1cd9 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -25,11 +24,9 @@ import java.io.IOException; import java.util.Collection; -import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Function; import java.util.stream.Collectors; @@ -39,9 +36,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; @@ -60,7 +57,7 @@ public class KnownContentLengthAsyncRequestBodySubscriberTest { private static final int TOTAL_NUM_PARTS = 4; private static final String UPLOAD_ID = "1234"; private static RandomTempFile testFile; - + private AsyncRequestBody asyncRequestBody; private PutObjectRequest putObjectRequest; private S3AsyncClient s3AsyncClient; @@ -114,7 +111,7 @@ void validatePart_withPartSizeExceedingLimit_shouldFailRequest() { void validateLastPartSize_withIncorrectSize_shouldFailRequest() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; long incorrectLastPartSize = expectedLastPartSize + 1; - + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); @@ -130,12 +127,12 @@ void validateLastPartSize_withIncorrectSize_shouldFailRequest() { @Test void validateTotalPartNum_receivedMoreParts_shouldFail() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; - + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { - AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + CloseableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); lastPartSubscriber.onNext(regularPart); @@ -157,7 +154,7 @@ void validateLastPartSize_withCorrectSize_shouldNotFail() { subscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { - AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + CloseableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); subscriber.onNext(regularPart); @@ -175,7 +172,7 @@ void validateLastPartSize_withCorrectSize_shouldNotFail() { void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() { CompletableFuture completeMpuFuture = new CompletableFuture<>(); int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); verifyResumeToken(resumeToken, numExistingParts); @@ -187,7 +184,7 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { CompletableFuture completeMpuFuture = CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build()); int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); assertThat(resumeToken).isNull(); @@ -196,15 +193,15 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { @Test void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() { int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, null); verifyResumeToken(resumeToken, numExistingParts); } - - private S3ResumeToken testPauseScenario(int numExistingParts, + + private S3ResumeToken testPauseScenario(int numExistingParts, CompletableFuture completeMpuFuture) { - KnownContentLengthAsyncRequestBodySubscriber subscriber = + KnownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(createMpuRequestContextWithExistingParts(numExistingParts)); when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class), @@ -246,14 +243,14 @@ private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequest return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); } - private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private CloseableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); return mockBody; } - private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private CloseableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.empty()); return mockBody; } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java index 90e14dcff2dd..859f5aebacac 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java @@ -18,14 +18,20 @@ import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; import static com.github.tomakehurst.wiremock.client.WireMock.delete; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.put; +import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; +import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; import io.reactivex.rxjava3.core.Flowable; import java.io.InputStream; import java.net.URI; @@ -46,20 +52,19 @@ import org.reactivestreams.Subscriber; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; -import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.exception.SdkClientException; -import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; +import software.amazon.awssdk.core.async.BufferedSplittableAsyncRequestBody; +import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; -import software.amazon.awssdk.utils.async.SimplePublisher; @WireMockTest -@Timeout(100) +@Timeout(120) public class S3MultipartClientPutObjectWiremockTest { private static final String BUCKET = "Example-Bucket"; @@ -71,52 +76,43 @@ public class S3MultipartClientPutObjectWiremockTest { + ""; private S3AsyncClient s3AsyncClient; - public static Stream invalidAsyncRequestBodies() { + public static Stream retryableErrorTestCase() { return Stream.of( - Arguments.of("knownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(20L, null), - "Content length is missing on the AsyncRequestBody for part number"), - Arguments.of("unknownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(null, null), - "Content length is missing on the AsyncRequestBody for part number"), - Arguments.of("knownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(20L, 11L), - "Content length must not be greater than part size"), - Arguments.of("unknownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(null, 11L), - "Content length must not be greater than part size"), - Arguments.of("knownContentLength_sendMoreParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 3), - "The number of parts divided is not equal to the expected number of parts"), - Arguments.of("knownContentLength_sendFewerParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 1), - "The number of parts divided is not equal to the expected number of parts")); + Arguments.of("unknownContentLength_failOfConnectionReset", null, + aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)), + Arguments.of("unknownContentLength_failOf500", null, + aResponse().withStatus(500)), + Arguments.of("knownContentLength_failOfConnectionReset", 20L, + aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)), + Arguments.of("knownContentLength_failOf500", 20L, + aResponse().withStatus(500)) + ); } @BeforeEach public void setup(WireMockRuntimeInfo wiremock) { - stubFailedPutObjectCalls(); s3AsyncClient = S3AsyncClient.builder() .region(Region.US_EAST_1) .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) .credentialsProvider( StaticCredentialsProvider.create(AwsBasicCredentials.create("key", "secret"))) .multipartEnabled(true) - .multipartConfiguration(b -> b.minimumPartSizeInBytes(10L).apiCallBufferSizeInBytes(10L)) - .httpClientBuilder(AwsCrtAsyncHttpClient.builder()) + .multipartConfiguration(b -> b.minimumPartSizeInBytes(10L).apiCallBufferSizeInBytes(20L)) + .httpClientBuilder(NettyNioAsyncHttpClient.builder()) .build(); } - private void stubFailedPutObjectCalls() { + private void stubPutObject404Calls() { stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); stubFor(put(anyUrl()).willReturn(aResponse().withStatus(404))); stubFor(put(urlEqualTo("/Example-Bucket/Example-Object?partNumber=1&uploadId=string")).willReturn(aResponse().withStatus(200))); stubFor(delete(anyUrl()).willReturn(aResponse().withStatus(200))); } - private void stubSuccessfulPutObjectCalls() { - stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); - stubFor(put(anyUrl()).willReturn(aResponse().withStatus(200))); - } - - // https://github.com/aws/aws-sdk-java-v2/issues/4801 @Test void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() { + stubPutObject404Calls(); BlockingInputStreamAsyncRequestBody blockingInputStreamAsyncRequestBody = AsyncRequestBody.forBlockingInputStream(null); CompletableFuture putObjectResponse = s3AsyncClient.putObject( r -> r.bucket(BUCKET).key(KEY), blockingInputStreamAsyncRequestBody); @@ -132,6 +128,7 @@ void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() { @Test void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() { + stubPutObject404Calls(); BlockingInputStreamAsyncRequestBody blockingInputStreamAsyncRequestBody = AsyncRequestBody.forBlockingInputStream(1024L * 20); // must be larger than the buffer used in // InputStreamConsumingPublisher to trigger the error @@ -147,86 +144,94 @@ void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() { assertThatThrownBy(() -> putObjectResponse.join()).hasRootCauseInstanceOf(S3Exception.class); } - @ParameterizedTest(name = "{index} {0}") - @MethodSource("invalidAsyncRequestBodies") - void uploadWithIncorrectAsyncRequestBodySplit_contentLengthMismatch_shouldThrowException(String description, - TestPublisherWithIncorrectSplitImpl asyncRequestBody, - String errorMsg) { - stubSuccessfulPutObjectCalls(); - CompletableFuture putObjectResponse = s3AsyncClient.putObject( - r -> r.bucket(BUCKET).key(KEY), asyncRequestBody); + @ParameterizedTest + @MethodSource("retryableErrorTestCase") + void mpuWithBufferedSplittableAsyncRequestBody_partsFailOfRetryableError_shouldRetry(String description, + Long contentLength, + ResponseDefinitionBuilder responseDefinitionBuilder) { + stubUploadPartFailsInitialAttemptSucceedsUponRetryCalls(responseDefinitionBuilder); + List buffers = new ArrayList<>(); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + AsyncRequestBody asyncRequestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.ofNullable(contentLength); + } - assertThatThrownBy(() -> putObjectResponse.join()).hasMessageContaining(errorMsg) - .hasRootCauseInstanceOf(SdkClientException.class); + @Override + public void subscribe(Subscriber s) { + Flowable.fromIterable(buffers).subscribe(s); + } + }; + + s3AsyncClient.putObject(b -> b.bucket(BUCKET).key(KEY), BufferedSplittableAsyncRequestBody.create(asyncRequestBody)) + .join(); + + verify(2, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(1)))); + verify(2, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(2)))); } - private InputStream createUnlimitedInputStream() { - return new InputStream() { + @ParameterizedTest + @MethodSource("retryableErrorTestCase") + void mpuDefaultSplitImpl_partsFailOfRetryableError_shouldFail(String description, + Long contentLength, + ResponseDefinitionBuilder responseDefinitionBuilder) { + stubUploadPartFailsInitialAttemptSucceedsUponRetryCalls(responseDefinitionBuilder); + List buffers = new ArrayList<>(); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + AsyncRequestBody asyncRequestBody = new AsyncRequestBody() { @Override - public int read() { - return 1; + public Optional contentLength() { + return Optional.ofNullable(contentLength); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.fromIterable(buffers).subscribe(s); } }; + + assertThatThrownBy(() -> s3AsyncClient.putObject(b -> b.bucket(BUCKET).key(KEY), asyncRequestBody) + .join()) + .hasCauseInstanceOf(NonRetryableException.class) + .hasMessageContaining("Multiple subscribers detected."); + + verify(1, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(1)))); + verify(1, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(1)))); } - private static class TestPublisherWithIncorrectSplitImpl implements AsyncRequestBody { - private SimplePublisher simplePublisher = new SimplePublisher<>(); - private Long totalSize; - private Long partSize; - private Integer numParts; - - private TestPublisherWithIncorrectSplitImpl(Long totalSize, Long partSize) { - this.totalSize = totalSize; - this.partSize = partSize; - } - - private TestPublisherWithIncorrectSplitImpl(Long totalSize, long partSize, int numParts) { - this.totalSize = totalSize; - this.partSize = partSize; - this.numParts = numParts; - } - - @Override - public Optional contentLength() { - return Optional.ofNullable(totalSize); - } - - @Override - public void subscribe(Subscriber s) { - simplePublisher.subscribe(s); - } - - @Override - public SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { - List requestBodies = new ArrayList<>(); - int numAsyncRequestBodies = numParts == null ? 1 : numParts; - for (int i = 0; i < numAsyncRequestBodies; i++) { - requestBodies.add(new TestAsyncRequestBody(partSize)); - } - return SdkPublisher.adapt(Flowable.fromArray(requestBodies.toArray(new AsyncRequestBody[requestBodies.size()]))); - } + private void stubUploadPartFailsInitialAttemptSucceedsUponRetryCalls(ResponseDefinitionBuilder responseDefinitionBuilder) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); + stubUploadFailsInitialAttemptCalls(1, responseDefinitionBuilder); + stubUploadFailsInitialAttemptCalls(2, responseDefinitionBuilder); } - private static class TestAsyncRequestBody implements AsyncRequestBody { - private Long partSize; - private SimplePublisher simplePublisher = new SimplePublisher<>(); - - public TestAsyncRequestBody(Long partSize) { - this.partSize = partSize; - } - - @Override - public Optional contentLength() { - return Optional.ofNullable(partSize); - } - - @Override - public void subscribe(Subscriber s) { - simplePublisher.subscribe(s); - simplePublisher.send(ByteBuffer.wrap( - RandomStringUtils.randomAscii(Math.toIntExact(partSize)).getBytes())); - simplePublisher.complete(); - } + private void stubUploadFailsInitialAttemptCalls(int partNumber, ResponseDefinitionBuilder responseDefinitionBuilder) { + stubFor(put(anyUrl()) + .withQueryParam("partNumber", matching(String.valueOf(partNumber))) + .inScenario(String.valueOf(partNumber)) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(responseDefinitionBuilder) + .willSetStateTo("SecondAttempt" + partNumber)); + + stubFor(put(anyUrl()) + .withQueryParam("partNumber", matching(String.valueOf(partNumber))) + .inScenario(String.valueOf(partNumber)) + .whenScenarioStateIs("SecondAttempt" + partNumber) + .willReturn(aResponse().withStatus(200))); + } + + + private InputStream createUnlimitedInputStream() { + return new InputStream() { + @Override + public int read() { + return 1; + } + }; } } + diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java index 972f0b86241a..b7bd330a6e75 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java @@ -26,11 +26,9 @@ import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulCreateMultipartCall; import static software.amazon.awssdk.services.s3.internal.multipart.MpuTestUtils.stubSuccessfulUploadPartCalls; -import java.io.ByteArrayInputStream; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; -import java.io.InputStream; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -46,18 +44,17 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.CloseableAsyncRequestBody; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.testutils.RandomTempFile; -import software.amazon.awssdk.utils.StringInputStream; public class UploadWithUnknownContentLengthHelperTest { private static final String BUCKET = "bucket"; @@ -114,14 +111,14 @@ void upload_blockingInputStream_shouldInOrder() throws FileNotFoundException { @Test void uploadObject_withMissingContentLength_shouldFailRequest() { - AsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength(); + CloseableAsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength(); CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); verifyFailureWithMessage(future, "Content length is missing on the AsyncRequestBody for part number"); } @Test void uploadObject_withPartSizeExceedingLimit_shouldFailRequest() { - AsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1); + CloseableAsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1); CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); verifyFailureWithMessage(future, "Content length must not be greater than part size"); } @@ -139,27 +136,27 @@ private List createCompletedParts(int totalNumParts) { .collect(Collectors.toList()); } - private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private CloseableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); return mockBody; } - private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private CloseableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + CloseableAsyncRequestBody mockBody = mock(CloseableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.empty()); return mockBody; } - private CompletableFuture setupAndTriggerUploadFailure(AsyncRequestBody asyncRequestBody) { - SdkPublisher mockPublisher = mock(SdkPublisher.class); - when(asyncRequestBody.split(any(Consumer.class))).thenReturn(mockPublisher); + private CompletableFuture setupAndTriggerUploadFailure(CloseableAsyncRequestBody asyncRequestBody) { + SdkPublisher mockPublisher = mock(SdkPublisher.class); + when(asyncRequestBody.splitCloseable(any(Consumer.class))).thenReturn(mockPublisher); - ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class); + ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class); CompletableFuture future = helper.uploadObject(createPutObjectRequest(), asyncRequestBody); verify(mockPublisher).subscribe(subscriberCaptor.capture()); - Subscriber subscriber = subscriberCaptor.getValue(); + Subscriber subscriber = subscriberCaptor.getValue(); Subscription subscription = mock(Subscription.class); subscriber.onSubscribe(subscription); diff --git a/services/s3/src/test/resources/log4j2.properties b/services/s3/src/test/resources/log4j2.properties index ad5cb8e79a64..fc101dc6e0d7 100644 --- a/services/s3/src/test/resources/log4j2.properties +++ b/services/s3/src/test/resources/log4j2.properties @@ -36,3 +36,6 @@ rootLogger.appenderRef.stdout.ref = ConsoleAppender # #logger.netty.name = io.netty.handler.logging #logger.netty.level = debug + +#logger.multipart.name = software.amazon.awssdk.services.s3.internal.multipart +#logger.multipart.level = debug \ No newline at end of file diff --git a/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java b/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java index d2edcaac742d..635cfdc834ce 100644 --- a/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java +++ b/test/architecture-tests/src/test/java/software/amazon/awssdk/archtests/CodingConventionWithSuppressionTest.java @@ -32,6 +32,7 @@ import java.util.Set; import java.util.regex.Pattern; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.internal.async.RetryableSubAsyncRequestBody; import software.amazon.awssdk.core.internal.http.pipeline.stages.MakeHttpRequestStage; import software.amazon.awssdk.core.sync.ResponseTransformer; import software.amazon.awssdk.metrics.publishers.emf.EmfMetricLoggingPublisher; @@ -52,7 +53,8 @@ public class CodingConventionWithSuppressionTest { ArchUtils.classNameToPattern(MakeHttpRequestStage.class), ArchUtils.classNameToPattern("software.amazon.awssdk.services.s3.internal.crt.S3CrtResponseHandlerAdapter"), ArchUtils.classNameToPattern( - "software.amazon.awssdk.services.s3.internal.crt.CrtResponseFileResponseTransformer"))); + "software.amazon.awssdk.services.s3.internal.crt.CrtResponseFileResponseTransformer"), + ArchUtils.classNameToPattern(RetryableSubAsyncRequestBody.class))); private static final Set ALLOWED_ERROR_LOG_SUPPRESSION = new HashSet<>( Arrays.asList( diff --git a/test/s3-tests/src/it/java/software/amazon/awssdk/services/s3/regression/upload/UploadStreamingRegressionTesting.java b/test/s3-tests/src/it/java/software/amazon/awssdk/services/s3/regression/upload/UploadStreamingRegressionTesting.java index 151dd6b02192..8785ff44d6b1 100644 --- a/test/s3-tests/src/it/java/software/amazon/awssdk/services/s3/regression/upload/UploadStreamingRegressionTesting.java +++ b/test/s3-tests/src/it/java/software/amazon/awssdk/services/s3/regression/upload/UploadStreamingRegressionTesting.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.services.s3.regression.upload; +import static org.assertj.core.api.Assertions.as; import static org.assertj.core.api.Assertions.assertThat; import static software.amazon.awssdk.services.s3.regression.S3ChecksumsTestUtils.assumeNotAccessPointWithPathStyle; import static software.amazon.awssdk.services.s3.regression.S3ChecksumsTestUtils.crc32; @@ -46,6 +47,7 @@ import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; import software.amazon.awssdk.core.async.BlockingOutputStreamAsyncRequestBody; +import software.amazon.awssdk.core.async.BufferedSplittableAsyncRequestBody; import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; import software.amazon.awssdk.core.interceptor.Context; import software.amazon.awssdk.core.interceptor.ExecutionAttributes; @@ -356,6 +358,23 @@ protected TestAsyncBody getAsyncRequestBody(BodyType bodyType, ContentSize conte contentSize.precalculatedCrc32(), bodyType); } + case BUFFERED_SPLITTABLE_KNOWN_CONTENT_LENGTH: { + byte[] content = contentSize.byteContent(); + + AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromInputStream( + new ByteArrayInputStream(content), + (long) content.length, ASYNC_REQUEST_BODY_EXECUTOR); + return new TestAsyncBody(BufferedSplittableAsyncRequestBody.create(asyncRequestBody), content.length, + contentSize.precalculatedCrc32(), bodyType); + } + case BUFFERED_SPLITTABLE_UNKNOWN_CONTENT_LENGTH: { + byte[] content = contentSize.byteContent(); + + AsyncRequestBody asyncRequestBody = AsyncRequestBody.fromInputStream( + new ByteArrayInputStream(content), null, ASYNC_REQUEST_BODY_EXECUTOR); + return new TestAsyncBody(BufferedSplittableAsyncRequestBody.create(asyncRequestBody), content.length, + contentSize.precalculatedCrc32(), bodyType); + } default: throw new RuntimeException("Unsupported async body type: " + bodyType); } @@ -398,7 +417,9 @@ protected enum BodyType { BUFFERS_REMAINING, BLOCKING_INPUT_STREAM, - BLOCKING_OUTPUT_STREAM + BLOCKING_OUTPUT_STREAM, + BUFFERED_SPLITTABLE_KNOWN_CONTENT_LENGTH, + BUFFERED_SPLITTABLE_UNKNOWN_CONTENT_LENGTH } protected enum ContentSize {