diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..d4c61a017 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -398,7 +398,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An # Handle source if "source" in document: - result["source"] = {"bytes": document["source"]["bytes"]} + result["source"] = document["source"] # Handle optional fields if "citations" in document and document["citations"] is not None: diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..a2b2f2979 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal, Optional +from typing import Any, List, Literal, Optional from typing_extensions import TypedDict @@ -15,14 +15,18 @@ """Supported document formats.""" -class DocumentSource(TypedDict): +class DocumentSource(TypedDict, total=False): """Contains the content of a document. Attributes: bytes: The binary content of the document. + text: The text content of the document source. + content: The structured content of the document source. """ bytes: bytes + text: str + content: List[Any] class DocumentContent(TypedDict, total=False): diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2809e8a72..ea1ad574e 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -19,6 +19,7 @@ DEFAULT_BEDROCK_REGION, DEFAULT_READ_TIMEOUT, ) +from strands.types.content import ContentBlock from strands.types.exceptions import ModelThrottledException from strands.types.tools import ToolSpec @@ -2070,3 +2071,51 @@ async def test_stream_backward_compatibility_system_prompt(bedrock_client, model "system": [{"text": system_prompt}], } bedrock_client.converse_stream.assert_called_once_with(**expected_request) + + +def test_format_request_document_with_text_source(model): + """Test that _format_request_message_content correctly handles a document with a 'text' source.""" + document_text = "This is the document content." + content_block: ContentBlock = { + "document": { + "name": "test_doc", + "source": {"text": document_text}, + "format": "txt", + } + } + + formatted_content = model._format_request_message_content(content_block) + + assert formatted_content["document"]["source"] == {"text": document_text} + + +def test_format_request_document_with_bytes_source(model): + """Test that _format_request_message_content correctly handles a 'bytes' source.""" + + content_block: ContentBlock = { + "document": { + "name": "test_doc", + "source": {"bytes": b"some byte data"}, + "format": "txt", + } + } + + formatted_content = model._format_request_message_content(content_block) + + assert formatted_content["document"]["source"] == {"bytes": b"some byte data"} + + +def test_format_request_document_with_content_source(model): + """Test that _format_request_message_content correctly handles a 'content' source.""" + doc_content = [{"text": "structured content"}] + content_block: ContentBlock = { + "document": { + "name": "test_doc", + "source": {"content": doc_content}, + "format": "txt", + } + } + + formatted_content = model._format_request_message_content(content_block) + + assert formatted_content["document"]["source"] == {"content": doc_content}