def patched_from_converse(cls, response: dict[str, Any], capture_content: bool) -> bedrock_utils._Choice:
    """Build a ``_Choice`` from a Bedrock Converse / ConverseStream response.

    Be defensive about malformed responses (refer to #3958 for more context):
    a missing or ``None`` ``output`` / ``message`` is treated as an empty
    message, and a missing ``stopReason`` is passed through as ``None``
    instead of raising ``KeyError``.

    Args:
        response: The (possibly incomplete) response dictionary.
        capture_content: When true, the message content is copied into the
            resulting choice if the message carries no tool calls.

    Returns:
        The object produced by ``cls(message, stop_reason, index=0)``.
    """
    # ``or {}`` also covers keys that are present but explicitly set to None,
    # which ``dict.get(key, {})`` would hand back unchanged.
    output = response.get("output") or {}
    orig_message = output.get("message") or {}

    if role := orig_message.get("role"):
        message = {"role": role}
    else:
        # amazon.titan does not serialize the role
        message = {}

    if tool_calls := bedrock_utils.extract_tool_calls(orig_message, capture_content):
        message["tool_calls"] = tool_calls
    elif capture_content and (content := orig_message.get("content")):
        message["content"] = content

    # A malformed response may lack "stopReason"; telemetry code must not
    # raise into the caller because of that, so fall back to None.
    return cls(message, response.get("stopReason"), index=0)
def _test_patched_from_converse_with_malformed_response(self):
    """Check that the patched ``from_converse`` copes with a response lacking the ``output`` key."""
    response_without_output = {"stopReason": "end_turn"}

    result = bedrock_utils._Choice.from_converse(
        response_without_output,
        capture_content=False,
    )

    # The stop reason is carried over; with no output there is no message data.
    self.assertEqual(result.finish_reason, "end_turn")
    self.assertEqual(result.message, {})
    self.assertEqual(result.index, 0)