Skip to content

Commit 0917c59

Browse files
committed
fix Decimal errors (from DynamoDB)
1 parent 7c5f3b8 commit 0917c59

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

src/strands/session/dynamodb_session_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""DynamoDB-based session manager for cloud storage."""
22

33
import logging
4+
from decimal import Decimal
45
from typing import Any, List, Optional
56

67
import boto3
@@ -17,6 +18,23 @@
1718
logger = logging.getLogger(__name__)
1819

1920

21+
def _convert_decimals_to_native_types(obj: Any) -> Any:
22+
"""Convert Decimal objects to native Python types recursively.
23+
24+
DynamoDB's TypeDeserializer returns Decimal objects for numeric values,
25+
but other AWS services expect native Python int/float types.
26+
"""
27+
if isinstance(obj, Decimal):
28+
# Convert to int if it's a whole number, otherwise float
29+
return int(obj) if obj % 1 == 0 else float(obj)
30+
elif isinstance(obj, dict):
31+
return {key: _convert_decimals_to_native_types(value) for key, value in obj.items()}
32+
elif isinstance(obj, list):
33+
return [_convert_decimals_to_native_types(item) for item in obj]
34+
else:
35+
return obj
36+
37+
2038
class DynamoDBSessionManager(RepositorySessionManager, SessionRepository):
2139
"""DynamoDB-based session manager for cloud storage.
2240
@@ -148,6 +166,7 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]:
148166
return None
149167

150168
data = self.deserializer.deserialize(response["Item"]["data"])
169+
data = _convert_decimals_to_native_types(data)
151170
return Session.from_dict(data)
152171
except ClientError as e:
153172
raise SessionException(f"DynamoDB error reading session: {e}") from e
@@ -205,6 +224,7 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[
205224
return None
206225

207226
data = self.deserializer.deserialize(response["Item"]["data"])
227+
data = _convert_decimals_to_native_types(data)
208228
return SessionAgent.from_dict(data)
209229
except ClientError as e:
210230
raise SessionException(f"DynamoDB error reading agent: {e}") from e
@@ -263,6 +283,7 @@ def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs
263283
return None
264284

265285
data = self.deserializer.deserialize(response["Item"]["data"])
286+
data = _convert_decimals_to_native_types(data)
266287
return SessionMessage.from_dict(data)
267288
except ClientError as e:
268289
raise SessionException(f"DynamoDB error reading message: {e}") from e
@@ -322,6 +343,7 @@ def list_messages(
322343
messages = []
323344
for item in items:
324345
data = self.deserializer.deserialize(item["data"])
346+
data = _convert_decimals_to_native_types(data)
325347
messages.append(SessionMessage.from_dict(data))
326348

327349
return messages

tests/strands/session/test_dynamodb_session_manager.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
"""Tests for DynamoDBSessionManager."""
22

33
import time
4+
from decimal import Decimal
45

56
import boto3
67
import pytest
78
from botocore.config import Config as BotocoreConfig
89
from moto import mock_aws
910

1011
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
11-
from strands.session.dynamodb_session_manager import DynamoDBSessionManager
12+
from strands.session.dynamodb_session_manager import DynamoDBSessionManager, \
13+
_convert_decimals_to_native_types
1214
from strands.types.content import ContentBlock
1315
from strands.types.exceptions import SessionException
1416
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
@@ -202,6 +204,7 @@ def test_read_agent(dynamodb_manager, sample_session, sample_agent):
202204

203205
assert result.agent_id == sample_agent.agent_id
204206
assert result.state == sample_agent.state
207+
assert isinstance(result.conversation_manager_state.get("removed_message_count"), int)
205208

206209

207210
def test_read_nonexistent_agent(dynamodb_manager, sample_session):
@@ -265,7 +268,7 @@ def test_read_message(dynamodb_manager, sample_session, sample_agent, sample_mes
265268
dynamodb_manager.create_message(sample_session.session_id, sample_agent.agent_id, sample_message)
266269

267270
result = dynamodb_manager.read_message(sample_session.session_id, sample_agent.agent_id, sample_message.message_id)
268-
271+
assert isinstance(result.message_id, int)
269272
assert result.message_id == sample_message.message_id
270273
assert result.message["role"] == sample_message.message["role"]
271274
assert result.message["content"] == sample_message.message["content"]
@@ -306,6 +309,8 @@ def test_list_messages_all(dynamodb_manager, sample_session, sample_agent):
306309
result = dynamodb_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
307310

308311
assert len(result) == 5
312+
for msg in result:
313+
assert isinstance(msg.message_id, int)
309314

310315

311316
def test_list_messages_with_pagination(dynamodb_manager, sample_session, sample_agent):
@@ -398,3 +403,51 @@ def test__get_message_sk_invalid_message_id(message_id, dynamodb_manager):
398403
"""Test that message_id that is not an integer raises ValueError."""
399404
with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
400405
dynamodb_manager._get_message_sk("agent1", message_id)
406+
407+
408+
def test_convert_decimals_to_native_types():
409+
"""Test the Decimal conversion utility function."""
410+
# Test simple Decimal conversion
411+
assert _convert_decimals_to_native_types(Decimal('10')) == 10
412+
assert _convert_decimals_to_native_types(Decimal('10.5')) == 10.5
413+
assert _convert_decimals_to_native_types(Decimal('0')) == 0
414+
415+
# Test nested dictionary conversion
416+
data = {
417+
'limit': Decimal('10'),
418+
'max_length': Decimal('8000'),
419+
'temperature': Decimal('0.5'),
420+
'name': 'test',
421+
'enabled': True,
422+
'nested': {
423+
'count': Decimal('42'),
424+
'ratio': Decimal('3.14')
425+
}
426+
}
427+
428+
result = _convert_decimals_to_native_types(data)
429+
430+
assert result['limit'] == 10
431+
assert isinstance(result['limit'], int)
432+
assert result['max_length'] == 8000
433+
assert isinstance(result['max_length'], int)
434+
assert result['temperature'] == 0.5
435+
assert isinstance(result['temperature'], float)
436+
assert result['name'] == 'test'
437+
assert result['enabled'] is True
438+
assert result['nested']['count'] == 42
439+
assert isinstance(result['nested']['count'], int)
440+
assert result['nested']['ratio'] == 3.14
441+
assert isinstance(result['nested']['ratio'], float)
442+
443+
# Test list conversion
444+
list_data = [Decimal('1'), Decimal('2.5'), 'string', {'nested': Decimal('100')}]
445+
result = _convert_decimals_to_native_types(list_data)
446+
447+
assert result[0] == 1
448+
assert isinstance(result[0], int)
449+
assert result[1] == 2.5
450+
assert isinstance(result[1], float)
451+
assert result[2] == 'string'
452+
assert result[3]['nested'] == 100
453+
assert isinstance(result[3]['nested'], int)

0 commit comments

Comments
 (0)