Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions Azure.Core.TestCommon/AsyncGate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;

namespace Azure.Core.TestCommon;

/// <summary>
/// A gate for coordinating async operations in tests.
/// </summary>
[ExcludeFromCodeCoverage]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public class AsyncGate<TIn, TOut>
{
private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(10);
private readonly object _sync = new();
private TaskCompletionSource<TIn> _signalTaskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);
private TaskCompletionSource<TOut> _releaseTaskCompletionSource = new(TaskCreationOptions.RunContinuationsAsynchronously);

/// <summary>
/// Waits for a signal with the default timeout.
/// </summary>
public Task<TIn> WaitForSignal()
{
return TimeoutAfter(_signalTaskCompletionSource.Task, DefaultTimeout);
}

/// <summary>
/// Cycles through waiting for signal and releasing.
/// </summary>
public async Task<TIn> Cycle(TOut value = default!)
{
var signal = await WaitForSignal();
Release(value);
return signal;
}

/// <summary>
/// Cycles through waiting for signal and releasing with an exception.
/// </summary>
public async Task<TIn> CycleWithException(Exception exception)
{
var signal = await WaitForSignal();
ReleaseWithException(exception);
return signal;
}

/// <summary>
/// Releases the gate with a value.
/// </summary>
public void Release(TOut value = default!)
{
lock (_sync)
{
Reset().SetResult(value);
}
}

/// <summary>
/// Releases the gate with an exception.
/// </summary>
public void ReleaseWithException(Exception exception)
{
lock (_sync)
{
Reset().SetException(exception);
}
}

private TaskCompletionSource<TOut> Reset()
{
lock (_sync)
{
if (!_signalTaskCompletionSource.Task.IsCompleted)
{
throw new InvalidOperationException("No await call to release");
}

var releaseTaskCompletionSource = _releaseTaskCompletionSource;
_releaseTaskCompletionSource = new TaskCompletionSource<TOut>(TaskCreationOptions.RunContinuationsAsynchronously);
_signalTaskCompletionSource = new TaskCompletionSource<TIn>(TaskCreationOptions.RunContinuationsAsynchronously);
return releaseTaskCompletionSource;
}
}

/// <summary>
/// Waits for the gate to be released.
/// </summary>
public Task<TOut> WaitForRelease(TIn value = default!)
{
lock (_sync)
{
_signalTaskCompletionSource.SetResult(value);
return TimeoutAfter(_releaseTaskCompletionSource.Task, DefaultTimeout);
}
}

private static async Task<T> TimeoutAfter<T>(Task<T> task, TimeSpan timeout)
{
if (task.IsCompleted || Debugger.IsAttached)
{
return await task;
}

using var cts = new CancellationTokenSource();
if (task == await Task.WhenAny(task, Task.Delay(timeout, cts.Token)))
{
await cts.CancelAsync();
return await task;
}

throw new TimeoutException($"Operation timed out after {timeout}");
}
}
#pragma warning restore CS1591
100 changes: 100 additions & 0 deletions Azure.Core.TestCommon/AsyncValidatingStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics.CodeAnalysis;

namespace Azure.Core.TestCommon;

/// <summary>
/// A stream wrapper that validates sync/async usage consistency.
/// </summary>
[ExcludeFromCodeCoverage]
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
internal class AsyncValidatingStream : Stream
{
private readonly bool _isAsync;
private readonly Stream _innerStream;

public AsyncValidatingStream(bool isAsync, Stream innerStream)
{
_isAsync = isAsync;
_innerStream = innerStream;
}

public override void Flush()
{
Validate(false);
_innerStream.Flush();
}

public override Task FlushAsync(CancellationToken cancellationToken)
{
Validate(true);
return _innerStream.FlushAsync(cancellationToken);
}

private void Validate(bool isAsync)
{
if (isAsync != _isAsync)
{
throw new InvalidOperationException(
$"All stream calls were expected to be {(_isAsync ? "async" : "sync")} but were {(isAsync ? "async" : "sync")}");
}
}

public override int Read(byte[] buffer, int offset, int count)
{
Validate(false);
return _innerStream.Read(buffer, offset, count);
}

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
Validate(true);
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override long Seek(long offset, SeekOrigin origin)
{
return _innerStream.Seek(offset, origin);
}

public override void SetLength(long value)
{
_innerStream.SetLength(value);
}

public override void Write(byte[] buffer, int offset, int count)
{
Validate(false);
_innerStream.Write(buffer, offset, count);
}

public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
{
Validate(true);
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
Validate(true);
return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken);
}

public override void Close()
{
_innerStream.Close();
}

public override bool CanRead => _innerStream.CanRead;
public override bool CanSeek => _innerStream.CanSeek;
public override bool CanWrite => _innerStream.CanWrite;
public override long Length => _innerStream.Length;
public override long Position
{
get => _innerStream.Position;
set => _innerStream.Position = value;
}
}
#pragma warning restore CS1591
25 changes: 25 additions & 0 deletions Azure.Core.TestCommon/Azure.Core.TestCommon.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
<Project Sdk="Microsoft.NET.Sdk">

<!--
Curated subset of Azure.Core.TestFramework for testing Azure SDK clients.
Source: https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/core/Azure.Core.TestFramework/src
Copied: 2026-03-06 (Azure.Core 1.51.1 era)

The full upstream framework is not usable as a submodule or NuGet package
because it depends on the Azure SDK's internal eng/ build system and
internal-only packages. This project contains only the mock types needed
for unit testing (MockResponse, MockTransport, MockRequest, etc.).
-->

<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.Core" />
</ItemGroup>

</Project>
105 changes: 105 additions & 0 deletions Azure.Core.TestCommon/DictionaryHeaders.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics.CodeAnalysis;

namespace Azure.Core.TestCommon;

/// <summary>
/// An implementation for manipulating headers on Request.
/// </summary>
[ExcludeFromCodeCoverage]
internal class DictionaryHeaders
{
private readonly Dictionary<string, object> _headers = new(StringComparer.OrdinalIgnoreCase);

/// <summary>
/// Adds a header value to the header collection.
/// </summary>
public void AddHeader(string name, string value)
{
if (!_headers.TryGetValue(name, out object? objValue))
{
_headers[name] = value;
}
else
{
if (objValue is List<string> values)
{
values.Add(value);
}
else
{
_headers[name] = new List<string> { (objValue as string)!, value };
}
}
}

/// <summary>
/// Sets a header value, replacing any existing values.
/// </summary>
public void SetHeader(string name, string value)
{
_headers[name] = value;
}

/// <summary>
/// Returns header value if the header is stored in the collection.
/// </summary>
public bool TryGetHeader(string name, out string? value)
{
if (_headers.TryGetValue(name, out object? objValue))
{
value = objValue is List<string> values ? JoinHeaderValue(values) : objValue as string;
return true;
}

value = null;
return false;
}

/// <summary>
/// Returns header values if the header is stored in the collection.
/// </summary>
public bool TryGetHeaderValues(string name, out IEnumerable<string>? values)
{
if (_headers.TryGetValue(name, out object? objValue))
{
values = objValue is List<string> valuesList
? valuesList
: new List<string> { (objValue as string)! };
return true;
}

values = null;
return false;
}

/// <summary>
/// Removes a header from the collection.
/// </summary>
public bool RemoveHeader(string name)
{
return _headers.Remove(name);
}

/// <summary>
/// Enumerates all headers in the collection.
/// </summary>
public IEnumerable<HttpHeader> EnumerateHeaders()
{
foreach (var kvp in _headers)
{
if (kvp.Value is List<string> values)
{
yield return new HttpHeader(kvp.Key, JoinHeaderValue(values));
}
else
{
yield return new HttpHeader(kvp.Key, (kvp.Value as string)!);
}
}
}

private static string JoinHeaderValue(IEnumerable<string> values) => string.Join(",", values);
}
Loading
Loading