Skip to content

Commit 3d59b7f

Browse files
author
Motta Kin
committed
Cast source to specific type; update tests to use _map_messages
1 parent 8f0cc33 commit 3d59b7f

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import anyio.to_thread
1313
from botocore.exceptions import ClientError
14+
from mypy_boto3_bedrock_runtime.type_defs import DocumentSourceTypeDef
1415
from typing_extensions import ParamSpec, assert_never
1516

1617
from pydantic_ai import (
@@ -647,15 +648,15 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
647648
if item.kind == 'image-url':
648649
format = item.media_type.split('/')[1]
649650
assert format in ('jpeg', 'png', 'gif', 'webp'), f'Unsupported image format: {format}'
650-
image: ImageBlockTypeDef = {'format': format, 'source': cast(Any, source)}
651+
image: ImageBlockTypeDef = {'format': format, 'source': cast(DocumentSourceTypeDef, source)}
651652
content.append({'image': image})
652653

653654
elif item.kind == 'document-url':
654655
name = f'Document {next(document_count)}'
655656
document: DocumentBlockTypeDef = {
656657
'name': name,
657658
'format': item.format,
658-
'source': cast(Any, source),
659+
'source': cast(DocumentSourceTypeDef, source),
659660
}
660661
content.append({'document': document})
661662

@@ -672,7 +673,7 @@ async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int])
672673
'wmv',
673674
'three_gp',
674675
), f'Unsupported video format: {format}'
675-
video: VideoBlockTypeDef = {'format': format, 'source': cast(Any, source)}
676+
video: VideoBlockTypeDef = {'format': format, 'source': cast(DocumentSourceTypeDef, source)}
676677
content.append({'video': video})
677678
elif isinstance(item, AudioUrl): # pragma: no cover
678679
raise NotImplementedError('Audio is not supported yet.')

tests/models/test_bedrock.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -729,42 +729,95 @@ async def test_text_document_url_input(allow_model_requests: None, bedrock_provi
729729
""")
730730

731731

732-
@pytest.mark.vcr()
733-
async def test_s3_image_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
732+
async def test_s3_image_url_input(bedrock_provider: BedrockProvider):
734733
"""Test that s3:// image URLs are passed directly to Bedrock API without downloading."""
735-
m = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
736-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
734+
model = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
737735
image_url = ImageUrl(url='s3://my-bucket/images/test-image.jpg', media_type='image/jpeg')
738736

739-
result = await agent.run(['What is in this image?', image_url])
740-
assert result.output == snapshot(
741-
'The image shows a scenic landscape with mountains in the background and a clear blue sky above.'
737+
req = [
738+
ModelRequest(parts=[UserPromptPart(content=['What is in this image?', image_url])]),
739+
]
740+
741+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
742+
743+
assert bedrock_messages == snapshot(
744+
[
745+
{
746+
'role': 'user',
747+
'content': [
748+
{'text': 'What is in this image?'},
749+
{
750+
'image': {
751+
'format': 'jpeg',
752+
'source': {'s3Location': {'uri': 's3://my-bucket/images/test-image.jpg'}},
753+
}
754+
},
755+
],
756+
}
757+
]
742758
)
743759

744760

745-
@pytest.mark.vcr()
746-
async def test_s3_video_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
761+
async def test_s3_video_url_input(bedrock_provider: BedrockProvider):
747762
"""Test that s3:// video URLs are passed directly to Bedrock API."""
748-
m = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
749-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
763+
model = BedrockConverseModel('us.amazon.nova-pro-v1:0', provider=bedrock_provider)
750764
video_url = VideoUrl(url='s3://my-bucket/videos/test-video.mp4', media_type='video/mp4')
751765

752-
result = await agent.run(['Describe this video', video_url])
753-
assert result.output == snapshot(
754-
'The video shows a time-lapse of a sunset over the ocean with waves gently rolling onto the shore.'
766+
# Create a ModelRequest with the S3 video URL
767+
req = [
768+
ModelRequest(parts=[UserPromptPart(content=['Describe this video', video_url])]),
769+
]
770+
771+
# Call the mapping function directly
772+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
773+
774+
assert bedrock_messages == snapshot(
775+
[
776+
{
777+
'role': 'user',
778+
'content': [
779+
{'text': 'Describe this video'},
780+
{
781+
'video': {
782+
'format': 'mp4',
783+
'source': {'s3Location': {'uri': 's3://my-bucket/videos/test-video.mp4'}},
784+
}
785+
},
786+
],
787+
}
788+
]
755789
)
756790

757791

758-
@pytest.mark.vcr()
759-
async def test_s3_document_url_input(allow_model_requests: None, bedrock_provider: BedrockProvider):
792+
async def test_s3_document_url_input(bedrock_provider: BedrockProvider):
760793
"""Test that s3:// document URLs are passed directly to Bedrock API."""
761-
m = BedrockConverseModel('anthropic.claude-v2', provider=bedrock_provider)
762-
agent = Agent(m, system_prompt='You are a helpful chatbot.')
794+
model = BedrockConverseModel('anthropic.claude-v2', provider=bedrock_provider)
763795
document_url = DocumentUrl(url='s3://my-bucket/documents/test-doc.pdf', media_type='application/pdf')
764796

765-
result = await agent.run(['What is the main content on this document?', document_url])
766-
assert result.output == snapshot(
767-
'Based on the provided document, the main content discusses best practices for cloud storage and data management.'
797+
# Create a ModelRequest with the S3 document URL
798+
req = [
799+
ModelRequest(parts=[UserPromptPart(content=['What is the main content on this document?', document_url])]),
800+
]
801+
802+
# Call the mapping function directly
803+
_, bedrock_messages = await model._map_messages(req, ModelRequestParameters()) # type: ignore[reportPrivateUsage]
804+
805+
assert bedrock_messages == snapshot(
806+
[
807+
{
808+
'role': 'user',
809+
'content': [
810+
{'text': 'What is the main content on this document?'},
811+
{
812+
'document': {
813+
'format': 'pdf',
814+
'name': 'Document 1',
815+
'source': {'s3Location': {'uri': 's3://my-bucket/documents/test-doc.pdf'}},
816+
}
817+
},
818+
],
819+
}
820+
]
768821
)
769822

770823

0 commit comments

Comments
 (0)