1
1
from strands .session .session_repository import SessionRepository
2
2
from strands .types .exceptions import SessionException
3
+ from strands .types .session import SessionAgent , SessionMessage
3
4
4
5
5
6
class MockedSessionRepository (SessionRepository ):
@@ -11,21 +12,20 @@ def __init__(self):
11
12
self .agents = {}
12
13
self .messages = {}
13
14
14
- def create_session (self , session ):
15
+ def create_session (self , session ) -> None :
15
16
"""Create a session."""
16
17
session_id = session .session_id
17
18
if session_id in self .sessions :
18
19
raise SessionException (f"Session { session_id } already exists" )
19
20
self .sessions [session_id ] = session
20
21
self .agents [session_id ] = {}
21
22
self .messages [session_id ] = {}
22
- return session
23
23
24
- def read_session (self , session_id ):
24
+ def read_session (self , session_id ) -> SessionAgent :
25
25
"""Read a session."""
26
26
return self .sessions .get (session_id )
27
27
28
- def create_agent (self , session_id , session_agent ):
28
+ def create_agent (self , session_id , session_agent ) -> None :
29
29
"""Create an agent."""
30
30
agent_id = session_agent .agent_id
31
31
if session_id not in self .sessions :
@@ -36,13 +36,13 @@ def create_agent(self, session_id, session_agent):
36
36
self .messages .setdefault (session_id , {}).setdefault (agent_id , {})
37
37
return session_agent
38
38
39
- def read_agent (self , session_id , agent_id ):
39
+ def read_agent (self , session_id , agent_id ) -> SessionAgent :
40
40
"""Read an agent."""
41
41
if session_id not in self .sessions :
42
42
return None
43
43
return self .agents .get (session_id , {}).get (agent_id )
44
44
45
- def update_agent (self , session_id , session_agent ):
45
+ def update_agent (self , session_id , session_agent ) -> None :
46
46
"""Update an agent."""
47
47
agent_id = session_agent .agent_id
48
48
if session_id not in self .sessions :
@@ -51,7 +51,7 @@ def update_agent(self, session_id, session_agent):
51
51
raise SessionException (f"Agent { agent_id } does not exist in session { session_id } " )
52
52
self .agents [session_id ][agent_id ] = session_agent
53
53
54
- def create_message (self , session_id , agent_id , session_message ):
54
+ def create_message (self , session_id , agent_id , session_message ) -> None :
55
55
"""Create a message."""
56
56
message_id = session_message .message_id
57
57
if session_id not in self .sessions :
@@ -62,15 +62,15 @@ def create_message(self, session_id, agent_id, session_message):
62
62
raise SessionException (f"Message { message_id } already exists in agent { agent_id } in session { session_id } " )
63
63
self .messages .setdefault (session_id , {}).setdefault (agent_id , {})[message_id ] = session_message
64
64
65
- def read_message (self , session_id , agent_id , message_id ):
65
+ def read_message (self , session_id , agent_id , message_id ) -> SessionMessage :
66
66
"""Read a message."""
67
67
if session_id not in self .sessions :
68
68
return None
69
69
if agent_id not in self .agents .get (session_id , {}):
70
70
return None
71
71
return self .messages .get (session_id , {}).get (agent_id , {}).get (message_id )
72
72
73
- def update_message (self , session_id , agent_id , session_message ):
73
+ def update_message (self , session_id , agent_id , session_message ) -> None :
74
74
"""Update a message."""
75
75
76
76
message_id = session_message .message_id
@@ -82,7 +82,7 @@ def update_message(self, session_id, agent_id, session_message):
82
82
raise SessionException (f"Message { message_id } does not exist in session { session_id } " )
83
83
self .messages [session_id ][agent_id ][message_id ] = session_message
84
84
85
- def list_messages (self , session_id , agent_id , limit = None , offset = 0 ):
85
+ def list_messages (self , session_id , agent_id , limit = None , offset = 0 ) -> list [ SessionMessage ] :
86
86
"""List messages."""
87
87
if session_id not in self .sessions :
88
88
return []
0 commit comments