diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs index bfe33dbda368..34f6fec1e9d5 100644 --- a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs @@ -20,6 +20,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Linq; using System.Runtime.CompilerServices; using System.Text; @@ -35,6 +36,11 @@ internal sealed partial class BedrockChatClient : IChatClient /// A default logger to use. private static readonly ILogger DefaultLogger = Logger.GetLogger(typeof(BedrockChatClient)); + /// The name used for the synthetic tool that enforces response format. + private const string ResponseFormatToolName = "generate_response"; + /// The description used for the synthetic tool that enforces response format. + private const string ResponseFormatToolDescription = "Generate response in specified format"; + /// The wrapped instance. private readonly IAmazonBedrockRuntime _runtime; /// Default model ID to use when no model is specified in the request. @@ -63,6 +69,13 @@ public void Dispose() } /// + /// + /// When is specified, the model must support + /// the ToolChoice feature. Models without this support will return an error from the Bedrock API + /// (typically with ErrorCode "ValidationException"). + /// If the model fails to return the expected structured output, + /// is thrown. + /// public async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { @@ -79,7 +92,7 @@ public async Task GetResponseAsync( request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options); request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options); - var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + ConverseResponse response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); ChatMessage result = new() { @@ -89,6 +102,48 @@ public async Task GetResponseAsync( MessageId = Guid.NewGuid().ToString("N"), }; + // Check if ResponseFormat was used and extract structured content + // When ResponseFormat is active, Bedrock returns the JSON response as a ToolUseBlock + // within the Message.Content array, rather than as plain text. + bool usingResponseFormat = options?.ResponseFormat is ChatResponseFormatJson; + if (usingResponseFormat) + { + // Search the response's ContentBlocks for our synthetic tool's input + // (ConverseResponse.Output.Message.Content contains a list of ContentBlock objects) + Document toolInput = GetResponseFormatToolInput(response.Output?.Message); + if (toolInput != default) + { + string structuredContent = DocumentToJsonString(toolInput); + + // Return only the structured JSON content, not the ToolUseBlock metadata + // This gives the user clean JSON conforming to their schema + result.Contents.Add(new TextContent(structuredContent) { RawRepresentation = response.Output?.Message }); + + if (DocumentToDictionary(response.AdditionalModelResponseFields) is { } responseFieldsDict) + { + result.AdditionalProperties = new(responseFieldsDict); + } + + return new(result) + { + CreatedAt = result.CreatedAt, + FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null, + Usage = response.Usage is TokenUsage tokenUsage ? CreateUsageDetails(tokenUsage) : null, + RawRepresentation = response, + }; + } + else + { + // Model succeeded but did not return expected structured output + throw new InvalidOperationException( + $"Model '{request.ModelId}' did not return structured output as requested. " + + "This may indicate the model refused to follow the tool use instruction, " + + "the schema was too complex, or the prompt conflicted with the requirement. " + + $"StopReason: {response.StopReason?.Value ?? "unknown"}."); + } + } + + // Normal content processing when not using ResponseFormat or extraction failed if (response.Output?.Message?.Content is { } contents) { foreach (var content in contents) @@ -182,6 +237,20 @@ public async IAsyncEnumerable GetStreamingResponseAsync( throw new ArgumentNullException(nameof(messages)); } + // ResponseFormat is not supported for streaming because it requires forcing a specific + // tool via ToolChoice. Since we create a synthetic tool for ResponseFormat and set + // toolChoice to force its use, this conflicts with the dynamic nature of streaming responses + // where tool calls may be interleaved with text content. + // + // For more information about tool use in streaming, see: + // https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + if (options?.ResponseFormat is ChatResponseFormatJson) + { + throw new NotSupportedException( + "ResponseFormat is not yet supported for streaming responses with Amazon Bedrock. " + + "Please use GetResponseAsync for structured output."); + } + ConverseStreamRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseStreamRequest ?? new(); request.ModelId ??= options?.ModelId ?? _modelId; request.Messages = CreateMessages(request.Messages, messages); @@ -794,84 +863,220 @@ private static Document ToDocument(JsonElement json) } } - /// Creates an from the specified options. + /// Creates a from the specified options. private static ToolConfiguration? CreateToolConfig(ToolConfiguration? toolConfig, ChatOptions? options) { if (options?.Tools is { Count: > 0 } tools) { - foreach (AITool tool in tools) + toolConfig = AddUserTools(toolConfig, tools); + } + + if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + toolConfig = AddResponseFormatTool(toolConfig, jsonFormat); + } + + if (toolConfig?.Tools is { Count: > 0 } && toolConfig.ToolChoice is null) + { + toolConfig = ApplyToolMode(toolConfig, options); + } + + return toolConfig; + } + + /// Adds user-provided tools to the tool configuration. + private static ToolConfiguration AddUserTools(ToolConfiguration? toolConfig, IList tools) + { + foreach (AITool tool in tools) + { + if (tool is not AIFunctionDeclaration f) { - if (tool is not AIFunctionDeclaration f) - { - continue; - } + continue; + } - Document inputs = default; - List required = []; + Document inputs = default; + List required = []; - if (f.JsonSchema.TryGetProperty("properties", out JsonElement properties)) + if (f.JsonSchema.TryGetProperty("properties", out JsonElement properties)) + { + foreach (JsonProperty parameter in properties.EnumerateObject()) { - foreach (JsonProperty parameter in properties.EnumerateObject()) - { - inputs.Add(parameter.Name, ToDocument(parameter.Value)); - } + inputs.Add(parameter.Name, ToDocument(parameter.Value)); } + } - if (f.JsonSchema.TryGetProperty("required", out JsonElement requiredProperties)) + if (f.JsonSchema.TryGetProperty("required", out JsonElement requiredProperties)) + { + foreach (JsonElement requiredProperty in requiredProperties.EnumerateArray()) { - foreach (JsonElement requiredProperty in requiredProperties.EnumerateArray()) - { - required.Add(requiredProperty.GetString()); - } + required.Add(requiredProperty.GetString()); } + } - Dictionary schemaDictionary = new() - { - ["type"] = new Document("object"), - }; + Dictionary schemaDictionary = new() + { + ["type"] = new Document("object"), + }; - if (inputs != default) - { - schemaDictionary["properties"] = inputs; - } + if (inputs != default) + { + schemaDictionary["properties"] = inputs; + } - if (required.Count > 0) - { - schemaDictionary["required"] = new Document(required); - } + if (required.Count > 0) + { + schemaDictionary["required"] = new Document(required); + } - toolConfig ??= new(); - toolConfig.Tools ??= []; - toolConfig.Tools.Add(new() + toolConfig ??= new(); + toolConfig.Tools ??= []; + toolConfig.Tools.Add(new() + { + ToolSpec = new ToolSpecification() { - ToolSpec = new ToolSpecification() + Name = f.Name, + Description = !string.IsNullOrEmpty(f.Description) ? f.Description : f.Name, + InputSchema = new() { - Name = f.Name, - Description = !string.IsNullOrEmpty(f.Description) ? f.Description : f.Name, - InputSchema = new() - { - Json = new(schemaDictionary) - }, + Json = new(schemaDictionary) }, - }); - } + }, + }); } - if (toolConfig?.Tools is { Count: > 0 } && toolConfig.ToolChoice is null) + return toolConfig!; + } + + /// Adds the ResponseFormat synthetic tool to enforce structured output. + private static ToolConfiguration AddResponseFormatTool(ToolConfiguration? toolConfig, ChatResponseFormatJson jsonFormat) + { + // Check for conflict with user-provided tools + // Bedrock's ToolChoice can only force ONE specific tool at a time. Since ResponseFormat + // works by creating a synthetic tool and forcing its use via toolChoice, we cannot + // simultaneously support user-provided tools (which may have their own toolChoice requirements). + // This is a Bedrock API constraint, not an SDK limitation. + if (toolConfig?.Tools?.Count > 0) { - switch (options!.ToolMode) + throw new ArgumentException( + "ResponseFormat cannot be used with Tools in Amazon Bedrock. " + + "ResponseFormat uses Bedrock's tool mechanism for structured output, " + + "which conflicts with user-provided tools."); + } + + // Create the synthetic tool with the schema from ResponseFormat + toolConfig ??= new(); + toolConfig.Tools ??= []; + + // Parse the schema if provided, otherwise create an empty object schema + Document schemaDoc; + if (jsonFormat.Schema.HasValue) + { + // Schema is already a JsonElement (parsed JSON), convert directly to Document + schemaDoc = ToDocument(jsonFormat.Schema.Value); + } + else + { + // For JSON mode without schema, create a generic object schema + schemaDoc = new Document(new Dictionary { - case RequiredChatToolMode r: - toolConfig.ToolChoice = !string.IsNullOrWhiteSpace(r.RequiredFunctionName) ? - new ToolChoice() { Tool = new() { Name = r.RequiredFunctionName } } : - new ToolChoice() { Any = new() }; - break; + ["type"] = new Document("object"), + ["additionalProperties"] = new Document(true) + }); + } + + toolConfig.Tools.Add(new Tool + { + ToolSpec = new ToolSpecification + { + Name = ResponseFormatToolName, + Description = jsonFormat.SchemaDescription ?? ResponseFormatToolDescription, + InputSchema = new ToolInputSchema + { + Json = schemaDoc + } } + }); + + // Force the model to use the synthetic tool + toolConfig.ToolChoice = new ToolChoice { Tool = new() { Name = ResponseFormatToolName } }; + + return toolConfig; + } + + /// Applies ToolMode configuration to set ToolChoice if not already set. + private static ToolConfiguration ApplyToolMode(ToolConfiguration toolConfig, ChatOptions? options) + { + switch (options!.ToolMode) + { + case RequiredChatToolMode r: + toolConfig.ToolChoice = !string.IsNullOrWhiteSpace(r.RequiredFunctionName) ? + new ToolChoice() { Tool = new() { Name = r.RequiredFunctionName } } : + new ToolChoice() { Any = new() }; + break; } return toolConfig; } + /// + /// Gets the tool input from the synthetic ResponseFormat tool, if present. + /// + /// The Bedrock Message object containing the response ContentBlocks. + /// + /// The tool input if found, otherwise default (Document is a struct). + /// + /// + /// + /// Bedrock returns responses as ConverseResponse.Output.Message.Content, which is a list of ContentBlock objects. + /// Each ContentBlock can contain one of several types: Text, ToolUse, Image, Video, Document, etc. + /// + /// + /// When ResponseFormat is specified, we create a synthetic tool ("generate_response") and force the model to use it + /// via ToolChoice. The model returns its structured JSON response as a ToolUseBlock within the ContentBlock list, + /// rather than as plain Text. + /// + /// + /// This method searches through the ContentBlock list to find the ToolUseBlock matching our synthetic tool name, + /// then extracts the Document from toolUse.Input. This Document contains the structured JSON conforming to the + /// user's schema. + /// + /// + private static Document GetResponseFormatToolInput(Message? message) + { + if (message?.Content is null) + { + return default; + } + + // Message.Content is a List - each block can be Text, ToolUse, Image, etc. + // We're searching for the ToolUseBlock that matches our synthetic tool name. + foreach (var content in message.Content) + { + if (content.ToolUse is ToolUseBlock toolUse && + toolUse.Name == ResponseFormatToolName && + toolUse.Input != default) + { + return toolUse.Input; + } + } + + return default; + } + + /// + /// Converts a to a JSON string using the SDK's standard DocumentMarshaller. + /// Note: Document is a struct (value type), so circular references are structurally impossible. + /// + private static string DocumentToJsonString(Document document) + { + using var stream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false })) + { + Amazon.Runtime.Documents.Internal.Transform.DocumentMarshaller.Instance.Write(writer, document); + } + return Encoding.UTF8.GetString(stream.ToArray()); + } + /// Creates an from the specified options. private static InferenceConfiguration CreateInferenceConfiguration(InferenceConfiguration config, ChatOptions? options) { diff --git a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs index 8f5099c973d8..b9dff182a517 100644 --- a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs +++ b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs @@ -1,11 +1,44 @@ -using Microsoft.Extensions.AI; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime; +using Amazon.Runtime.Documents; +using Amazon.Runtime.Internal; +using Amazon.Runtime.Internal.Transform; +using Microsoft.Extensions.AI; +using Moq; using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; using Xunit; namespace Amazon.BedrockRuntime; +// Simple test implementation of AIFunctionDeclaration +internal sealed class TestAIFunction : AIFunctionDeclaration +{ + public TestAIFunction(string name, string description, JsonElement jsonSchema) + { + Name = name; + Description = description; + JsonSchema = jsonSchema; + } + + public override string Name { get; } + public override string Description { get; } + public override JsonElement JsonSchema { get; } +} + public class BedrockChatClientTests { + #region Basic Client Tests + [Fact] [Trait("UnitTest", "BedrockRuntime")] public void AsIChatClient_InvalidArguments_Throws() @@ -19,8 +52,8 @@ public void AsIChatClient_InvalidArguments_Throws() [InlineData("claude")] public void AsIChatClient_ReturnsInstance(string modelId) { - IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); - IChatClient client = runtime.AsIChatClient(modelId); + var mockRuntime = new Mock(); + IChatClient client = mockRuntime.Object.AsIChatClient(modelId); Assert.NotNull(client); Assert.Equal("aws.bedrock", client.GetService()?.ProviderName); @@ -31,17 +64,1136 @@ public void AsIChatClient_ReturnsInstance(string modelId) [Trait("UnitTest", "BedrockRuntime")] public void AsIChatClient_GetService() { - IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); - IChatClient client = runtime.AsIChatClient(); + var mockRuntime = new Mock(); + IChatClient client = mockRuntime.Object.AsIChatClient(); - Assert.Same(runtime, client.GetService()); - Assert.Same(runtime, client.GetService()); + Assert.Same(mockRuntime.Object, client.GetService()); Assert.Same(client, client.GetService()); - Assert.Null(client.GetService()); - - Assert.Null(client.GetService("key")); Assert.Null(client.GetService("key")); - Assert.Null(client.GetService("key")); } + + #endregion + + #region ResponseFormat Tests + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithSchema_CreatesSyntheticToolWithCorrectSchema() + { + // Arrange + var mockRuntime = new Mock(); + ConverseRequest capturedRequest = null; + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .Callback((req, ct) => capturedRequest = req) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "test-id", + Name = "generate_response", + Input = new Document(new Dictionary + { + ["name"] = new Document("John Doe"), + ["age"] = new Document(30) + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" } + }, + "required": ["name"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "PersonSchema", + schemaDescription: "A person object") + }; + + // Act + await client.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(capturedRequest); + var tool = capturedRequest.ToolConfig.Tools[0]; + Assert.Equal("generate_response", tool.ToolSpec.Name); + Assert.Equal("A person object", tool.ToolSpec.Description); + + // Verify schema structure matches input + var schema = tool.ToolSpec.InputSchema.Json; + Assert.True(schema.IsDictionary()); + var schemaDict = schema.AsDictionary(); + + Assert.Equal("object", schemaDict["type"].AsString()); + Assert.True(schemaDict.ContainsKey("properties")); + + var properties = schemaDict["properties"].AsDictionary(); + Assert.True(properties.ContainsKey("name")); + Assert.True(properties.ContainsKey("age")); + Assert.Equal("string", properties["name"].AsDictionary()["type"].AsString()); + Assert.Equal("number", properties["age"].AsDictionary()["type"].AsString()); + + Assert.True(schemaDict.ContainsKey("required")); + var required = schemaDict["required"].AsList(); + Assert.Single(required); + Assert.Equal("name", required[0].AsString()); + + // Verify the mock was called + mockRuntime.Verify(x => x.ConverseAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_ModelReturnsToolUse_ExtractsJsonCorrectly() + { + // Arrange + var mockRuntime = new Mock(); + + // Setup mock to return tool use with structured data + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "test-id", + Name = "generate_response", + Input = new Document(new Dictionary + { + ["city"] = new Document("Seattle"), + ["temperature"] = new Document(72), + ["conditions"] = new Document("sunny") + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use"), + Usage = new TokenUsage { InputTokens = 10, OutputTokens = 20, TotalTokens = 30 } + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Get weather") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act + var response = await client.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Text); + + // Parse the JSON to verify structure + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString()); + Assert.Equal(72, json.RootElement.GetProperty("temperature").GetInt32()); + Assert.Equal("sunny", json.RootElement.GetProperty("conditions").GetString()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithTools_ThrowsArgumentException() + { + // Arrange + var mockRuntime = new Mock(); + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + + // Create test tool + var tool = new TestAIFunction("test", "Test tool", JsonDocument.Parse("{}").RootElement); + + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + Tools = new[] { tool } + }; + + // Act & Assert + await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_UnsupportedModel_ThrowsValidationException() + { + // Arrange + var mockRuntime = new Mock(); + + // Setup mock to throw BedrockRuntimeException with toolChoice error + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new AmazonBedrockRuntimeException("ValidationException: toolChoice is not supported by this model") + { + ErrorCode = "ValidationException" + }); + + var client = mockRuntime.Object.AsIChatClient("titan"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Equal("ValidationException", ex.ErrorCode); + Assert.Contains("toolChoice is not supported", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_ModelReturnsText_ThrowsInvalidOperationException() + { + // Arrange - Model returns text instead of tool_use + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock { Text = "Here is some text" } + } + } + }, + StopReason = new StopReason("end_turn") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("did not return structured output", ex.Message); + Assert.Contains("end_turn", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WrongToolName_ThrowsInvalidOperationException() + { + // Arrange - Model uses wrong tool name + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "wrong-id", + Name = "wrong_tool_name", + Input = new Document(new Dictionary + { + ["data"] = new Document("value") + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("did not return structured output", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_EmptyToolInput_ReturnsEmptyJson() + { + // Arrange - Tool with empty input + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "empty-id", + Name = "generate_response", + Input = new Document(new Dictionary()) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act + var response = await client.GetResponseAsync(messages, options); + + // Assert - Empty object is valid JSON + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + Assert.Equal(JsonValueKind.Object, json.RootElement.ValueKind); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_NullToolInput_ThrowsInvalidOperationException() + { + // Arrange - ToolUse with default/null Input (edge case: malformed API response) + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "null-input-id", + Name = "generate_response", + Input = default // Default/null Document + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert - Should throw InvalidOperationException, not NullReferenceException + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("did not return structured output", ex.Message); + } + + #endregion } + +/// +/// Tests using HTTP-layer mocking to test actual Converse API response scenarios. +/// This allows testing beyond the happy path with realistic service responses. +/// Based on Peter's (peterrsongg) suggestion to test different response structures. +/// +public class BedrockChatClientHttpMockedTests : IClassFixture +{ + private readonly HttpMockFixture _fixture; + + public BedrockChatClientHttpMockedTests(HttpMockFixture fixture) + { + _fixture = fixture; + } + + /// + /// Helper method to inject stubbed web response data into a request's state + /// + private static void InjectMockedResponse(ConverseRequest request, StubWebResponseData webResponseData) + { + var interfaceType = typeof(IAmazonWebServiceRequest); + var requestStatePropertyInfo = interfaceType.GetProperty("RequestState"); + var requestState = (Dictionary)requestStatePropertyInfo.GetValue(request); + requestState["response"] = webResponseData; + } + + #region HTTP Mocking Infrastructure (Based on Peter's Working Code) + + /// + /// Pipeline customizer that replaces the HTTP handler with a mock implementation + /// + private class MockPipelineCustomizer : IRuntimePipelineCustomizer + { + public string UniqueName => "BedrockMEAIMockPipeline"; + + public void Customize(Type type, RuntimePipeline pipeline) + { +#if NETFRAMEWORK + // On .NET Framework, use Stream + pipeline.ReplaceHandler>( + new HttpHandler(new MockHttpRequestFactory(), new object())); +#else + // On .NET Core/.NET 5+, use HttpContent + pipeline.ReplaceHandler>( + new HttpHandler(new MockHttpRequestFactory(), new object())); +#endif + } + } + + /// + /// Factory for creating mock HTTP requests + /// +#if NETFRAMEWORK + private class MockHttpRequestFactory : IHttpRequestFactory + { + public IHttpRequest CreateHttpRequest(Uri requestUri) + { + return new MockHttpRequest(requestUri); + } +#else + private class MockHttpRequestFactory : IHttpRequestFactory + { + public IHttpRequest CreateHttpRequest(Uri requestUri) + { + return new MockHttpRequest(requestUri); + } +#endif + + public void Dispose() + { + // No resources to dispose + } + } + + /// + /// Mock HTTP request that retrieves stubbed response data from request state + /// +#if NETFRAMEWORK + private class MockHttpRequest : IHttpRequest +#else + private class MockHttpRequest : IHttpRequest +#endif + { + private IWebResponseData _webResponseData; + + public MockHttpRequest(Uri requestUri) + { + RequestUri = requestUri; + } + + public string Method { get; set; } + public Uri RequestUri { get; set; } + public Version HttpProtocolVersion { get; set; } + + public void ConfigureRequest(IRequestContext requestContext) + { + // Retrieve the stubbed response from request state + // This is the critical line that Peter identified (line 60 in his comment) + var request = requestContext.OriginalRequest as IAmazonWebServiceRequest; + if (request != null && request.RequestState.ContainsKey("response")) + { + _webResponseData = request.RequestState["response"] as IWebResponseData; + } + } + + public void SetRequestHeaders(IDictionary headers) + { + // Not needed for mock + } + +#if NETFRAMEWORK + public Stream GetRequestContent() + { + return new MemoryStream(); + } +#else + public HttpContent GetRequestContent() + { + return null; + } +#endif + + public IWebResponseData GetResponse() + { + return GetResponseAsync(CancellationToken.None).Result; + } + + public Task GetResponseAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_webResponseData); + } + +#if NETFRAMEWORK + public void WriteToRequestBody(Stream requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + // Not needed for mock + } + + public void WriteToRequestBody(Stream requestContent, byte[] content, + IDictionary contentHeaders) + { + // Not needed for mock + } + + public Task WriteToRequestBodyAsync(Stream requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + return Task.CompletedTask; + } + + public Task WriteToRequestBodyAsync(Stream requestContent, byte[] content, + IDictionary contentHeaders, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } +#else + public void WriteToRequestBody(HttpContent requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + // Not needed for mock + } + + public void WriteToRequestBody(HttpContent requestContent, byte[] content, + IDictionary contentHeaders) + { + // Not needed for mock + } + + public Task WriteToRequestBodyAsync(HttpContent requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + return Task.CompletedTask; + } + + public Task WriteToRequestBodyAsync(HttpContent requestContent, byte[] content, + IDictionary contentHeaders, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } +#endif + + public IHttpRequestStreamHandle SetupHttpRequestStreamPublisher( + IDictionary contentHeaders, IHttpRequestStreamPublisher publisher) + { + throw new NotImplementedException(); + } + + public void Abort() + { + // Not needed for mock + } + +#if NETFRAMEWORK + public Task GetRequestContentAsync() + { + return Task.FromResult(new MemoryStream()); + } + + public Task GetRequestContentAsync(CancellationToken cancellationToken) + { + return Task.FromResult(new MemoryStream()); + } +#else + public Task GetRequestContentAsync() + { + return Task.FromResult(null); + } + + public Task GetRequestContentAsync(CancellationToken cancellationToken) + { + return Task.FromResult(null); + } +#endif + + public Stream SetupProgressListeners(Stream originalStream, long progressUpdateInterval, + object sender, EventHandler callback) + { + return originalStream; + } + + public void Dispose() + { + // Nothing to dispose + } + } + + /// + /// Stubbed web response data for testing different response scenarios + /// + private class StubWebResponseData : IWebResponseData + { + private readonly IHttpResponseBody _httpResponseBody; + + public StubWebResponseData(string jsonResponse, Dictionary headers = null, + HttpStatusCode statusCode = HttpStatusCode.OK) + { + StatusCode = statusCode; + IsSuccessStatusCode = (int)statusCode >= 200 && (int)statusCode < 300; + JsonResponse = jsonResponse; + Headers = headers ?? new Dictionary(StringComparer.OrdinalIgnoreCase); + ContentType = "application/json"; + ContentLength = jsonResponse?.Length ?? 0; + + _httpResponseBody = new HttpResponseBody(jsonResponse); + } + + public Dictionary Headers { get; set; } + public string JsonResponse { get; } + public long ContentLength { get; set; } + public string ContentType { get; set; } + public HttpStatusCode StatusCode { get; set; } + public bool IsSuccessStatusCode { get; set; } + + public IHttpResponseBody ResponseBody => _httpResponseBody; + + public string[] GetHeaderNames() + { + return Headers.Keys.ToArray(); + } + + public bool IsHeaderPresent(string headerName) + { + return Headers.ContainsKey(headerName); + } + + public string GetHeaderValue(string headerName) + { + return Headers.ContainsKey(headerName) ? Headers[headerName] : null; + } + } + + /// + /// HTTP response body implementation for stubbed responses + /// + private class HttpResponseBody : IHttpResponseBody + { + private readonly string _jsonResponse; + private Stream _stream; + + public HttpResponseBody(string jsonResponse) + { + _jsonResponse = jsonResponse ?? string.Empty; + } + + public void Dispose() + { + _stream?.Dispose(); + } + + public Stream OpenResponse() + { + _stream = new MemoryStream(Encoding.UTF8.GetBytes(_jsonResponse)); + return _stream; + } + + public Task OpenResponseAsync() + { + return Task.FromResult(OpenResponse()); + } + } + + #endregion + + #region ResponseFormat with HTTP Mocking Tests + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithActualConverseResponse_ParsesCorrectly() + { + // Arrange - This is a real Converse API response with tool_use + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_12345", + "name": "generate_response", + "input": { + "name": "Alice Johnson", + "age": 28, + "city": "Seattle" + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 125, + "outputTokens": 45, + "totalTokens": 170 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate a person") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" }, + "city": { "type": "string" } + }, + "required": ["name", "age"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "PersonSchema", + schemaDescription: "A person with demographic information"), + RawRepresentationFactory = _ => request + }; + + // Inject the stubbed response + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Text); + + // Verify the JSON structure + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Alice Johnson", json.RootElement.GetProperty("name").GetString()); + Assert.Equal(28, json.RootElement.GetProperty("age").GetInt32()); + Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString()); + + // Verify usage metadata + var usage = response.Usage; + Assert.NotNull(usage); + Assert.Equal(125, usage.InputTokenCount); + Assert.Equal(45, usage.OutputTokenCount); + Assert.Equal(170, usage.TotalTokenCount); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithNestedObjects_ParsesCorrectly() + { + // Arrange - Test with nested JSON structure + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_nested", + "name": "generate_response", + "input": { + "user": { + "name": "Bob Smith", + "contact": { + "email": "bob@example.com", + "phone": "555-0123" + } + }, + "metadata": { + "timestamp": "2024-01-15T10:30:00Z", + "version": 1 + } + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 200, + "outputTokens": 80, + "totalTokens": 280 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate user data") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + var user = json.RootElement.GetProperty("user"); + Assert.Equal("Bob Smith", user.GetProperty("name").GetString()); + + var contact = user.GetProperty("contact"); + Assert.Equal("bob@example.com", contact.GetProperty("email").GetString()); + Assert.Equal("555-0123", contact.GetProperty("phone").GetString()); + + var metadata = json.RootElement.GetProperty("metadata"); + Assert.Equal("2024-01-15T10:30:00Z", metadata.GetProperty("timestamp").GetString()); + Assert.Equal(1, metadata.GetProperty("version").GetInt32()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithArrayData_ParsesCorrectly() + { + // Arrange - Test with arrays in JSON response + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_array", + "name": "generate_response", + "input": { + "items": ["apple", "banana", "orange"], + "prices": [1.99, 0.99, 2.49], + "inventory": { + "warehouse": "A", + "quantities": [100, 250, 75] + } + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 50, + "outputTokens": 30, + "totalTokens": 80 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "List items") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + var items = json.RootElement.GetProperty("items"); + Assert.Equal(JsonValueKind.Array, items.ValueKind); + Assert.Equal(3, items.GetArrayLength()); + Assert.Equal("apple", items[0].GetString()); + Assert.Equal("banana", items[1].GetString()); + Assert.Equal("orange", items[2].GetString()); + + var prices = json.RootElement.GetProperty("prices"); + Assert.Equal(3, prices.GetArrayLength()); + Assert.Equal(1.99, prices[0].GetDouble(), precision: 2); + + var inventory = json.RootElement.GetProperty("inventory"); + var quantities = inventory.GetProperty("quantities"); + Assert.Equal(3, quantities.GetArrayLength()); + Assert.Equal(100, quantities[0].GetInt32()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithMinimalSchema_ParsesCorrectly() + { + // Arrange - Test simple JSON response + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_simple", + "name": "generate_response", + "input": { + "message": "Hello, World!", + "status": "success" + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-haiku-20240307-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Say hello") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Hello, World!", json.RootElement.GetProperty("message").GetString()); + Assert.Equal("success", json.RootElement.GetProperty("status").GetString()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithComplexSchema_ValidatesStructure() + { + // Arrange - Test with detailed schema validation + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_complex", + "name": "generate_response", + "input": { + "id": "usr_123", + "username": "testuser", + "email": "test@example.com", + "profile": { + "firstName": "Test", + "lastName": "User", + "age": 25, + "preferences": { + "theme": "dark", + "notifications": true + } + }, + "roles": ["admin", "user"], + "active": true + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 300, + "outputTokens": 150, + "totalTokens": 450 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate user profile") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "id": { "type": "string" }, + "username": { "type": "string" }, + "email": { "type": "string", "format": "email" }, + "profile": { + "type": "object", + "properties": { + "firstName": { "type": "string" }, + "lastName": { "type": "string" }, + "age": { "type": "number" }, + "preferences": { "type": "object" } + }, + "required": ["firstName", "lastName"] + }, + "roles": { + "type": "array", + "items": { "type": "string" } + }, + "active": { "type": "boolean" } + }, + "required": ["id", "username", "email"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "UserProfile", + schemaDescription: "Complete user profile with preferences"), + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + // Verify required fields + Assert.Equal("usr_123", json.RootElement.GetProperty("id").GetString()); + Assert.Equal("testuser", json.RootElement.GetProperty("username").GetString()); + Assert.Equal("test@example.com", json.RootElement.GetProperty("email").GetString()); + + // Verify nested profile + var profile = json.RootElement.GetProperty("profile"); + Assert.Equal("Test", profile.GetProperty("firstName").GetString()); + Assert.Equal("User", profile.GetProperty("lastName").GetString()); + Assert.Equal(25, profile.GetProperty("age").GetInt32()); + + // Verify nested preferences + var preferences = profile.GetProperty("preferences"); + Assert.Equal("dark", preferences.GetProperty("theme").GetString()); + Assert.True(preferences.GetProperty("notifications").GetBoolean()); + + // Verify array + var roles = json.RootElement.GetProperty("roles"); + Assert.Equal(2, roles.GetArrayLength()); + Assert.Equal("admin", roles[0].GetString()); + Assert.Equal("user", roles[1].GetString()); + + // Verify boolean + Assert.True(json.RootElement.GetProperty("active").GetBoolean()); + } + + #endregion + + /// + /// Test fixture that registers the HTTP mocking pipeline customizer + /// + public class HttpMockFixture : IDisposable + { + private readonly MockPipelineCustomizer _customizer; + + public HttpMockFixture() + { + // Register the mock pipeline customizer globally + _customizer = new MockPipelineCustomizer(); + Runtime.Internal.RuntimePipelineCustomizerRegistry.Instance.Register(_customizer); + + // Create the Bedrock Runtime client - it will use the mocked pipeline + BedrockRuntimeClient = new AmazonBedrockRuntimeClient(); + } + + public IAmazonBedrockRuntime BedrockRuntimeClient { get; private set; } + + public void Dispose() + { + // Clean up + Runtime.Internal.RuntimePipelineCustomizerRegistry.Instance.Deregister(_customizer); + BedrockRuntimeClient?.Dispose(); + } + } +} \ No newline at end of file diff --git a/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj index dd9de35ce4a5..915b0769fe83 100644 --- a/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj +++ b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj @@ -19,6 +19,7 @@ + diff --git a/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json b/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json new file mode 100644 index 000000000000..297c2483daf0 --- /dev/null +++ b/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json @@ -0,0 +1,11 @@ +{ + "extensions": [ + { + "extensionName": "Extensions.Bedrock.MEAI", + "type": "minor", + "changeLogMessages": [ + "Add support for ChatOptions.ResponseFormat to enable structured JSON responses using JSON Schema." + ] + } + ] +}