Skip to content

Commit 6a31b25

Browse files
Fix download directory concurrency issue and refactor (#4180)
1 parent c2109a6 commit 6a31b25

File tree

5 files changed

+962
-154
lines changed

5 files changed

+962
-154
lines changed

sdk/src/Services/S3/Custom/Transfer/Internal/DownloadDirectoryCommand.cs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,7 @@ internal partial class DownloadDirectoryCommand : BaseCommand<TransferUtilityDow
4848
long _transferredBytes;
4949
string _currentFile;
5050

51-
internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request)
52-
: this(s3Client, request, useMultipartDownload: false)
53-
{
54-
}
55-
56-
internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, bool useMultipartDownload)
51+
internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, TransferUtilityConfig config, bool useMultipartDownload)
5752
{
5853
if (s3Client == null)
5954
throw new ArgumentNullException(nameof(s3Client));
@@ -62,6 +57,7 @@ internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDir
6257

6358
this._s3Client = s3Client;
6459
this._request = request;
60+
this._config = config;
6561
this._skipEncryptionInstructionFiles = s3Client is Amazon.S3.Internal.IAmazonS3Encryption;
6662
_failurePolicy =
6763
request.FailurePolicy == FailurePolicy.AbortOnFailure
@@ -70,12 +66,6 @@ internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDir
7066
this._useMultipartDownload = useMultipartDownload;
7167
}
7268

73-
internal DownloadDirectoryCommand(IAmazonS3 s3Client, TransferUtilityDownloadDirectoryRequest request, TransferUtilityConfig config, bool useMultipartDownload)
74-
: this(s3Client, request, useMultipartDownload)
75-
{
76-
this._config = config;
77-
}
78-
7969
private void downloadedProgressEventCallback(object sender, WriteObjectProgressArgs e)
8070
{
8171
var transferredBytes = Interlocked.Add(ref _transferredBytes, e.IncrementTransferred);

sdk/src/Services/S3/Custom/Transfer/Internal/TaskHelpers.cs

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@
1313
* permissions and limitations under the License.
1414
*/
1515

16+
using System;
1617
using System.Collections.Generic;
18+
using System.Linq;
1719
using System.Threading;
1820
using System.Threading.Tasks;
21+
using Amazon.Runtime.Internal.Util;
1922

2023
namespace Amazon.S3.Transfer.Internal
2124
{
@@ -24,6 +27,11 @@ namespace Amazon.S3.Transfer.Internal
2427
/// </summary>
2528
internal static class TaskHelpers
2629
{
30+
/// <summary>
/// Logger used by the helpers in this class. It is looked up through
/// <c>Logger.GetLogger(typeof(TaskHelpers))</c> each time the property is read.
/// </summary>
private static Logger Logger => Logger.GetLogger(typeof(TaskHelpers));
34+
2735
/// <summary>
2836
/// Waits for all tasks to complete or till any task fails or is canceled.
2937
/// </summary>
@@ -33,7 +41,10 @@ internal static class TaskHelpers
3341
internal static async Task WhenAllOrFirstExceptionAsync(List<Task> pendingTasks, CancellationToken cancellationToken)
3442
{
3543
int processed = 0;
36-
int total = pendingTasks.Count;
44+
int total = pendingTasks.Count;
45+
46+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: Starting with TotalTasks={0}", total);
47+
3748
while (processed < total)
3849
{
3950
cancellationToken.ThrowIfCancellationRequested();
@@ -48,7 +59,12 @@ await completedTask
4859

4960
pendingTasks.Remove(completedTask);
5061
processed++;
62+
63+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: Task completed (Processed={0}/{1}, Remaining={2})",
64+
processed, total, pendingTasks.Count);
5165
}
66+
67+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync: All tasks completed (Total={0})", total);
5268
}
5369

5470
/// <summary>
@@ -64,6 +80,9 @@ internal static async Task<List<T>> WhenAllOrFirstExceptionAsync<T>(List<Task<T>
6480
int processed = 0;
6581
int total = pendingTasks.Count;
6682
var responses = new List<T>();
83+
84+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: Starting with TotalTasks={0}", total);
85+
6786
while (processed < total)
6887
{
6988
cancellationToken.ThrowIfCancellationRequested();
@@ -79,9 +98,92 @@ internal static async Task<List<T>> WhenAllOrFirstExceptionAsync<T>(List<Task<T>
7998

8099
pendingTasks.Remove(completedTask);
81100
processed++;
101+
102+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: Task completed (Processed={0}/{1}, Remaining={2})",
103+
processed, total, pendingTasks.Count);
82104
}
83105

106+
Logger.DebugFormat("TaskHelpers.WhenAllOrFirstExceptionAsync<T>: All tasks completed (Total={0})", total);
107+
84108
return responses;
85109
}
110+
111+
/// <summary>
/// Executes work items with limited concurrency using a task pool pattern.
/// Creates only as many tasks as the concurrency limit allows, rather than creating
/// all tasks upfront. This reduces memory overhead for large collections.
/// </summary>
/// <remarks>
/// This method provides a clean way to limit concurrent operations without creating
/// all tasks upfront. It maintains a pool of active tasks up to the maxConcurrency limit,
/// replacing completed tasks with new ones until all items are processed.
/// The caller is responsible for implementing failure handling within the processAsync function.
/// <para>
/// Behavior is fail-fast: the first faulted or canceled task (or an observed cancellation
/// request) ends this method immediately. Tasks that are still active at that point are
/// not awaited here.
/// </para>
/// <para>
/// Because this is an async method, argument validation failures are reported through the
/// returned task rather than thrown synchronously.
/// </para>
/// </remarks>
/// <typeparam name="T">The type of items to process</typeparam>
/// <param name="items">The collection of items to process. Must not be null.</param>
/// <param name="maxConcurrency">Maximum number of concurrent tasks. Must be at least 1.</param>
/// <param name="processAsync">Async function to process each item. Must not be null.</param>
/// <param name="cancellationToken">Cancellation token to observe</param>
/// <returns>A task that completes when all items are processed, or throws on first failure</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="items"/> or <paramref name="processAsync"/> is null.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="maxConcurrency"/> is less than 1.
/// </exception>
internal static async Task ForEachWithConcurrencyAsync<T>(
    IEnumerable<T> items,
    int maxConcurrency,
    Func<T, CancellationToken, Task> processAsync,
    CancellationToken cancellationToken)
{
    if (items == null)
    {
        throw new ArgumentNullException(nameof(items));
    }

    if (processAsync == null)
    {
        throw new ArgumentNullException(nameof(processAsync));
    }

    // A non-positive limit would start no tasks at all, skip the completion loop and
    // report success without processing a single item. Reject it explicitly instead.
    if (maxConcurrency < 1)
    {
        throw new ArgumentOutOfRangeException(
            nameof(maxConcurrency), maxConcurrency, "maxConcurrency must be at least 1.");
    }

    // Materialize once so the sequence is never enumerated twice and can be indexed.
    var itemList = items as IList<T> ?? items.ToList();
    if (itemList.Count == 0)
    {
        Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: No items to process");
        return;
    }

    Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting with TotalItems={0}, MaxConcurrency={1}",
        itemList.Count, maxConcurrency);

    // Start initial batch up to concurrency limit
    int initialBatchSize = Math.Min(maxConcurrency, itemList.Count);
    Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting initial batch of {0} tasks", initialBatchSize);

    int nextIndex = 0;
    var activeTasks = new List<Task>(initialBatchSize);

    for (int i = 0; i < initialBatchSize; i++)
    {
        var task = processAsync(itemList[nextIndex++], cancellationToken);
        activeTasks.Add(task);
    }

    // Process completions and start new tasks until all work is done
    while (activeTasks.Count > 0)
    {
        cancellationToken.ThrowIfCancellationRequested();

        var completedTask = await Task.WhenAny(activeTasks)
            .ConfigureAwait(continueOnCapturedContext: false);

        // Propagate exceptions (fail-fast behavior by default)
        // Caller's processAsync function should handle failure policy if needed
        await completedTask
            .ConfigureAwait(continueOnCapturedContext: false);

        activeTasks.Remove(completedTask);

        // Items started so far minus those still running gives the completed count.
        int itemsCompleted = nextIndex - activeTasks.Count;
        Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Task completed (Active={0}, Completed={1}/{2}, Remaining={3})",
            activeTasks.Count, itemsCompleted, itemList.Count, itemList.Count - itemsCompleted);

        // Start next task if more work remains
        if (nextIndex < itemList.Count)
        {
            Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: Starting next task (Index={0}/{1}, Active={2})",
                nextIndex + 1, itemList.Count, activeTasks.Count + 1);
            var nextTask = processAsync(itemList[nextIndex++], cancellationToken);
            activeTasks.Add(nextTask);
        }
    }

    Logger.DebugFormat("TaskHelpers.ForEachWithConcurrencyAsync: All items processed (Total={0})", itemList.Count);
}
86188
}
87189
}

0 commit comments

Comments
 (0)