11"""Unit tests for BidiAgent."""
22
3- import unittest .mock
43import asyncio
5- import pytest
4+ import unittest . mock
65from uuid import uuid4
76
7+ import pytest
8+
89from strands .experimental .bidi .agent .agent import BidiAgent
910from strands .experimental .bidi .models .nova_sonic import BidiNovaSonicModel
1011from strands .experimental .bidi .types .events import (
11- BidiTextInputEvent ,
1212 BidiAudioInputEvent ,
1313 BidiAudioStreamEvent ,
14- BidiTranscriptStreamEvent ,
15- BidiConnectionStartEvent ,
1614 BidiConnectionCloseEvent ,
15+ BidiConnectionStartEvent ,
16+ BidiTextInputEvent ,
17+ BidiTranscriptStreamEvent ,
1718)
1819
20+
1921class MockBidiModel :
2022 """Mock bidirectional model for testing."""
2123
@@ -46,42 +48,44 @@ async def receive(self):
4648 """Async generator yielding mock events."""
4749 if not self ._started :
4850 raise RuntimeError ("model not started | call start before sending/receiving" )
49-
51+
5052 # Yield connection start event
5153 yield BidiConnectionStartEvent (connection_id = self ._connection_id , model = self .model_id )
52-
54+
5355 # Yield any configured events
5456 for event in self ._events_to_yield :
5557 yield event
56-
58+
5759 # Yield connection end event
5860 yield BidiConnectionCloseEvent (connection_id = self ._connection_id , reason = "complete" )
5961
6062 def set_events (self , events ):
6163 """Helper to set events this mock model will yield."""
6264 self ._events_to_yield = events
6365
66+
6467@pytest .fixture
6568def mock_model ():
6669 """Create a mock BidiModel instance."""
6770 return MockBidiModel ()
6871
72+
6973@pytest .fixture
7074def mock_tool_registry ():
7175 """Mock tool registry with some basic tools."""
7276 registry = unittest .mock .Mock ()
7377 registry .get_all_tool_specs .return_value = [
7478 {
7579 "name" : "calculator" ,
76- "description" : "Perform calculations" ,
77- "inputSchema" : {"json" : {"type" : "object" , "properties" : {}}}
80+ "description" : "Perform calculations" ,
81+ "inputSchema" : {"json" : {"type" : "object" , "properties" : {}}},
7882 }
7983 ]
8084 registry .get_all_tools_config .return_value = {"calculator" : {}}
8185 return registry
8286
8387
84- @pytest .fixture
88+ @pytest .fixture
8589def mock_tool_caller ():
8690 """Mock tool caller for testing tool execution."""
8791 caller = unittest .mock .AsyncMock ()
@@ -94,203 +98,194 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller):
9498 """Create a BidiAgent instance for testing."""
9599 with unittest .mock .patch ("strands.experimental.bidi.agent.agent.ToolRegistry" ) as mock_registry_class :
96100 mock_registry_class .return_value = mock_tool_registry
97-
101+
98102 with unittest .mock .patch ("strands.experimental.bidi.agent.agent._ToolCaller" ) as mock_caller_class :
99103 mock_caller_class .return_value = mock_tool_caller
100-
104+
101105 # Don't pass tools to avoid real tool loading
102106 agent = BidiAgent (model = mock_model )
103107 return agent
104108
109+
105110def test_bidi_agent_init_with_various_configurations ():
106111 """Test agent initialization with various configurations."""
107112 # Test default initialization
108113 mock_model = MockBidiModel ()
109114 agent = BidiAgent (model = mock_model )
110-
115+
111116 assert agent .model == mock_model
112117 assert agent .system_prompt is None
113118 assert not agent ._started
114119 assert agent .model ._connection_id is None
115-
120+
116121 # Test with configuration
117122 system_prompt = "You are a helpful assistant."
118- agent_with_config = BidiAgent (
119- model = mock_model ,
120- system_prompt = system_prompt ,
121- agent_id = "test_agent"
122- )
123-
123+ agent_with_config = BidiAgent (model = mock_model , system_prompt = system_prompt , agent_id = "test_agent" )
124+
124125 assert agent_with_config .system_prompt == system_prompt
125126 assert agent_with_config .agent_id == "test_agent"
126-
127+
127128 # Test with string model ID
128129 model_id = "amazon.nova-sonic-v1:0"
129130 agent_with_string = BidiAgent (model = model_id )
130-
131+
131132 assert isinstance (agent_with_string .model , BidiNovaSonicModel )
132133 assert agent_with_string .model .model_id == model_id
133-
134+
134135 # Test model config access
135136 config = agent .model .config
136137 assert config ["audio" ]["input_rate" ] == 16000
137138 assert config ["audio" ]["output_rate" ] == 24000
138139 assert config ["audio" ]["channels" ] == 1
139140
141+
140142@pytest .mark .asyncio
141143async def test_bidi_agent_start_stop_lifecycle (agent ):
142144 """Test agent start/stop lifecycle and state management."""
143145 # Initial state
144146 assert not agent ._started
145147 assert agent .model ._connection_id is None
146-
148+
147149 # Start agent
148150 await agent .start ()
149151 assert agent ._started
150152 assert agent .model ._connection_id is not None
151153 connection_id = agent .model ._connection_id
152-
154+
153155 # Double start should error
154156 with pytest .raises (RuntimeError , match = "agent already started" ):
155157 await agent .start ()
156-
158+
157159 # Stop agent
158160 await agent .stop ()
159161 assert not agent ._started
160162 assert agent .model ._connection_id is None
161-
163+
162164 # Multiple stops should be safe
163165 await agent .stop ()
164166 await agent .stop ()
165-
167+
166168 # Restart should work with new connection ID
167169 await agent .start ()
168170 assert agent ._started
169171 assert agent .model ._connection_id != connection_id
170172
173+
171174@pytest .mark .asyncio
172175async def test_bidi_agent_send_with_input_types (agent ):
173176 """Test sending various input types through agent.send()."""
174177 await agent .start ()
175-
178+
176179 # Test text input with TypedEvent
177180 text_input = BidiTextInputEvent (text = "Hello" , role = "user" )
178181 await agent .send (text_input )
179182 assert len (agent .messages ) == 1
180183 assert agent .messages [0 ]["content" ][0 ]["text" ] == "Hello"
181-
184+
182185 # Test string input (shorthand)
183186 await agent .send ("World" )
184187 assert len (agent .messages ) == 2
185188 assert agent .messages [1 ]["content" ][0 ]["text" ] == "World"
186-
189+
187190 # Test audio input (doesn't add to messages)
188191 audio_input = BidiAudioInputEvent (
189192 audio = "dGVzdA==" , # base64 "test"
190193 format = "pcm" ,
191194 sample_rate = 16000 ,
192- channels = 1
195+ channels = 1 ,
193196 )
194197 await agent .send (audio_input )
195198 assert len (agent .messages ) == 2 # Still 2, audio doesn't add
196-
199+
197200 # Test concurrent sends
198- sends = [
199- agent .send (BidiTextInputEvent (text = f"Message { i } " , role = "user" ))
200- for i in range (3 )
201- ]
201+ sends = [agent .send (BidiTextInputEvent (text = f"Message { i } " , role = "user" )) for i in range (3 )]
202202 await asyncio .gather (* sends )
203203 assert len (agent .messages ) == 5 # 2 + 3 new messages
204204
205+
205206@pytest .mark .asyncio
206207async def test_bidi_agent_receive_events_from_model (agent ):
207208 """Test receiving events from model."""
208209 # Configure mock model to yield events
209210 events = [
210- BidiAudioStreamEvent (
211- audio = "dGVzdA==" ,
212- format = "pcm" ,
213- sample_rate = 24000 ,
214- channels = 1
215- ),
211+ BidiAudioStreamEvent (audio = "dGVzdA==" , format = "pcm" , sample_rate = 24000 , channels = 1 ),
216212 BidiTranscriptStreamEvent (
217213 text = "Hello world" ,
218214 role = "assistant" ,
219215 is_final = True ,
220216 delta = {"text" : "Hello world" },
221- current_transcript = "Hello world"
222- )
217+ current_transcript = "Hello world" ,
218+ ),
223219 ]
224220 agent .model .set_events (events )
225-
221+
226222 await agent .start ()
227-
223+
228224 received_events = []
229225 async for event in agent .receive ():
230226 received_events .append (event )
231227 if len (received_events ) >= 4 : # Stop after getting expected events
232228 break
233-
229+
234230 # Verify event types and order
235231 assert len (received_events ) >= 3
236232 assert isinstance (received_events [0 ], BidiConnectionStartEvent )
237233 assert isinstance (received_events [1 ], BidiAudioStreamEvent )
238234 assert isinstance (received_events [2 ], BidiTranscriptStreamEvent )
239-
235+
240236 # Test empty events
241237 agent .model .set_events ([])
242238 await agent .stop ()
243239 await agent .start ()
244-
240+
245241 empty_events = []
246242 async for event in agent .receive ():
247243 empty_events .append (event )
248244 if len (empty_events ) >= 2 :
249245 break
250-
246+
251247 assert len (empty_events ) >= 1
252248 assert isinstance (empty_events [0 ], BidiConnectionStartEvent )
253249
250+
254251def test_bidi_agent_tool_integration (agent , mock_tool_registry ):
255252 """Test agent tool integration and properties."""
256253 # Test tool property access
257- assert hasattr (agent , ' tool' )
254+ assert hasattr (agent , " tool" )
258255 assert agent .tool is not None
259256 assert agent .tool == agent ._tool_caller
260-
257+
261258 # Test tool names property
262- mock_tool_registry .get_all_tools_config .return_value = {
263- "calculator" : {},
264- "weather" : {}
265- }
266-
259+ mock_tool_registry .get_all_tools_config .return_value = {"calculator" : {}, "weather" : {}}
260+
267261 tool_names = agent .tool_names
268262 assert isinstance (tool_names , list )
269263 assert len (tool_names ) == 2
270264 assert "calculator" in tool_names
271265 assert "weather" in tool_names
272266
267+
273268@pytest .mark .asyncio
274269async def test_bidi_agent_send_receive_error_before_start (agent ):
275270 """Test error handling in various scenarios."""
276271 # Test send before start
277272 with pytest .raises (RuntimeError , match = "call start before" ):
278273 await agent .send (BidiTextInputEvent (text = "Hello" , role = "user" ))
279-
274+
280275 # Test receive before start
281276 with pytest .raises (RuntimeError , match = "call start before" ):
282- async for event in agent .receive ():
277+ async for _ in agent .receive ():
283278 pass
284-
279+
285280 # Test send after stop
286281 await agent .start ()
287282 await agent .stop ()
288283 with pytest .raises (RuntimeError , match = "call start before" ):
289284 await agent .send (BidiTextInputEvent (text = "Hello" , role = "user" ))
290-
285+
291286 # Test receive after stop
292287 with pytest .raises (RuntimeError , match = "call start before" ):
293- async for event in agent .receive ():
288+ async for _ in agent .receive ():
294289 pass
295290
296291
@@ -301,43 +296,44 @@ async def test_bidi_agent_start_receive_propagates_model_errors():
301296 mock_model = MockBidiModel ()
302297 mock_model .start = unittest .mock .AsyncMock (side_effect = Exception ("Connection failed" ))
303298 error_agent = BidiAgent (model = mock_model )
304-
299+
305300 with pytest .raises (Exception , match = "Connection failed" ):
306301 await error_agent .start ()
307-
302+
308303 # Test model receive error
309304 mock_model2 = MockBidiModel ()
310305 agent2 = BidiAgent (model = mock_model2 )
311306 await agent2 .start ()
312-
307+
313308 async def failing_receive ():
314309 yield BidiConnectionStartEvent (connection_id = "test" , model = "test-model" )
315310 raise Exception ("Receive failed" )
316-
311+
317312 agent2 .model .receive = failing_receive
318313 with pytest .raises (Exception , match = "Receive failed" ):
319- async for event in agent2 .receive ():
314+ async for _ in agent2 .receive ():
320315 pass
321316
317+
322318@pytest .mark .asyncio
323319async def test_bidi_agent_state_consistency (agent ):
324320 """Test that agent state remains consistent across operations."""
325321 # Initial state
326322 assert not agent ._started
327323 assert agent .model ._connection_id is None
328-
324+
329325 # Start
330326 await agent .start ()
331327 assert agent ._started
332328 assert agent .model ._connection_id is not None
333329 connection_id = agent .model ._connection_id
334-
330+
335331 # Send operations shouldn't change connection state
336332 await agent .send (BidiTextInputEvent (text = "Hello" , role = "user" ))
337333 assert agent ._started
338334 assert agent .model ._connection_id == connection_id
339-
335+
340336 # Stop
341337 await agent .stop ()
342338 assert not agent ._started
343- assert agent .model ._connection_id is None
339+ assert agent .model ._connection_id is None
0 commit comments