Skip to content

Commit 6543097

Browse files
authored
bidi - tests - lint (#1307)
1 parent 45dd597 commit 6543097

File tree

9 files changed

+100
-103
lines changed

9 files changed

+100
-103
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Bidirectional streaming agent tests."""
1+
"""Bidirectional streaming agent tests."""
Lines changed: 69 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
"""Unit tests for BidiAgent."""
22

3-
import unittest.mock
43
import asyncio
5-
import pytest
4+
import unittest.mock
65
from uuid import uuid4
76

7+
import pytest
8+
89
from strands.experimental.bidi.agent.agent import BidiAgent
910
from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel
1011
from strands.experimental.bidi.types.events import (
11-
BidiTextInputEvent,
1212
BidiAudioInputEvent,
1313
BidiAudioStreamEvent,
14-
BidiTranscriptStreamEvent,
15-
BidiConnectionStartEvent,
1614
BidiConnectionCloseEvent,
15+
BidiConnectionStartEvent,
16+
BidiTextInputEvent,
17+
BidiTranscriptStreamEvent,
1718
)
1819

20+
1921
class MockBidiModel:
2022
"""Mock bidirectional model for testing."""
2123

@@ -46,42 +48,44 @@ async def receive(self):
4648
"""Async generator yielding mock events."""
4749
if not self._started:
4850
raise RuntimeError("model not started | call start before sending/receiving")
49-
51+
5052
# Yield connection start event
5153
yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id)
52-
54+
5355
# Yield any configured events
5456
for event in self._events_to_yield:
5557
yield event
56-
58+
5759
# Yield connection end event
5860
yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete")
5961

6062
def set_events(self, events):
6163
"""Helper to set events this mock model will yield."""
6264
self._events_to_yield = events
6365

66+
6467
@pytest.fixture
6568
def mock_model():
6669
"""Create a mock BidiModel instance."""
6770
return MockBidiModel()
6871

72+
6973
@pytest.fixture
7074
def mock_tool_registry():
7175
"""Mock tool registry with some basic tools."""
7276
registry = unittest.mock.Mock()
7377
registry.get_all_tool_specs.return_value = [
7478
{
7579
"name": "calculator",
76-
"description": "Perform calculations",
77-
"inputSchema": {"json": {"type": "object", "properties": {}}}
80+
"description": "Perform calculations",
81+
"inputSchema": {"json": {"type": "object", "properties": {}}},
7882
}
7983
]
8084
registry.get_all_tools_config.return_value = {"calculator": {}}
8185
return registry
8286

8387

84-
@pytest.fixture
88+
@pytest.fixture
8589
def mock_tool_caller():
8690
"""Mock tool caller for testing tool execution."""
8791
caller = unittest.mock.AsyncMock()
@@ -94,203 +98,194 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller):
9498
"""Create a BidiAgent instance for testing."""
9599
with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class:
96100
mock_registry_class.return_value = mock_tool_registry
97-
101+
98102
with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class:
99103
mock_caller_class.return_value = mock_tool_caller
100-
104+
101105
# Don't pass tools to avoid real tool loading
102106
agent = BidiAgent(model=mock_model)
103107
return agent
104108

109+
105110
def test_bidi_agent_init_with_various_configurations():
106111
"""Test agent initialization with various configurations."""
107112
# Test default initialization
108113
mock_model = MockBidiModel()
109114
agent = BidiAgent(model=mock_model)
110-
115+
111116
assert agent.model == mock_model
112117
assert agent.system_prompt is None
113118
assert not agent._started
114119
assert agent.model._connection_id is None
115-
120+
116121
# Test with configuration
117122
system_prompt = "You are a helpful assistant."
118-
agent_with_config = BidiAgent(
119-
model=mock_model,
120-
system_prompt=system_prompt,
121-
agent_id="test_agent"
122-
)
123-
123+
agent_with_config = BidiAgent(model=mock_model, system_prompt=system_prompt, agent_id="test_agent")
124+
124125
assert agent_with_config.system_prompt == system_prompt
125126
assert agent_with_config.agent_id == "test_agent"
126-
127+
127128
# Test with string model ID
128129
model_id = "amazon.nova-sonic-v1:0"
129130
agent_with_string = BidiAgent(model=model_id)
130-
131+
131132
assert isinstance(agent_with_string.model, BidiNovaSonicModel)
132133
assert agent_with_string.model.model_id == model_id
133-
134+
134135
# Test model config access
135136
config = agent.model.config
136137
assert config["audio"]["input_rate"] == 16000
137138
assert config["audio"]["output_rate"] == 24000
138139
assert config["audio"]["channels"] == 1
139140

141+
140142
@pytest.mark.asyncio
141143
async def test_bidi_agent_start_stop_lifecycle(agent):
142144
"""Test agent start/stop lifecycle and state management."""
143145
# Initial state
144146
assert not agent._started
145147
assert agent.model._connection_id is None
146-
148+
147149
# Start agent
148150
await agent.start()
149151
assert agent._started
150152
assert agent.model._connection_id is not None
151153
connection_id = agent.model._connection_id
152-
154+
153155
# Double start should error
154156
with pytest.raises(RuntimeError, match="agent already started"):
155157
await agent.start()
156-
158+
157159
# Stop agent
158160
await agent.stop()
159161
assert not agent._started
160162
assert agent.model._connection_id is None
161-
163+
162164
# Multiple stops should be safe
163165
await agent.stop()
164166
await agent.stop()
165-
167+
166168
# Restart should work with new connection ID
167169
await agent.start()
168170
assert agent._started
169171
assert agent.model._connection_id != connection_id
170172

173+
171174
@pytest.mark.asyncio
172175
async def test_bidi_agent_send_with_input_types(agent):
173176
"""Test sending various input types through agent.send()."""
174177
await agent.start()
175-
178+
176179
# Test text input with TypedEvent
177180
text_input = BidiTextInputEvent(text="Hello", role="user")
178181
await agent.send(text_input)
179182
assert len(agent.messages) == 1
180183
assert agent.messages[0]["content"][0]["text"] == "Hello"
181-
184+
182185
# Test string input (shorthand)
183186
await agent.send("World")
184187
assert len(agent.messages) == 2
185188
assert agent.messages[1]["content"][0]["text"] == "World"
186-
189+
187190
# Test audio input (doesn't add to messages)
188191
audio_input = BidiAudioInputEvent(
189192
audio="dGVzdA==", # base64 "test"
190193
format="pcm",
191194
sample_rate=16000,
192-
channels=1
195+
channels=1,
193196
)
194197
await agent.send(audio_input)
195198
assert len(agent.messages) == 2 # Still 2, audio doesn't add
196-
199+
197200
# Test concurrent sends
198-
sends = [
199-
agent.send(BidiTextInputEvent(text=f"Message {i}", role="user"))
200-
for i in range(3)
201-
]
201+
sends = [agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) for i in range(3)]
202202
await asyncio.gather(*sends)
203203
assert len(agent.messages) == 5 # 2 + 3 new messages
204204

205+
205206
@pytest.mark.asyncio
206207
async def test_bidi_agent_receive_events_from_model(agent):
207208
"""Test receiving events from model."""
208209
# Configure mock model to yield events
209210
events = [
210-
BidiAudioStreamEvent(
211-
audio="dGVzdA==",
212-
format="pcm",
213-
sample_rate=24000,
214-
channels=1
215-
),
211+
BidiAudioStreamEvent(audio="dGVzdA==", format="pcm", sample_rate=24000, channels=1),
216212
BidiTranscriptStreamEvent(
217213
text="Hello world",
218214
role="assistant",
219215
is_final=True,
220216
delta={"text": "Hello world"},
221-
current_transcript="Hello world"
222-
)
217+
current_transcript="Hello world",
218+
),
223219
]
224220
agent.model.set_events(events)
225-
221+
226222
await agent.start()
227-
223+
228224
received_events = []
229225
async for event in agent.receive():
230226
received_events.append(event)
231227
if len(received_events) >= 4: # Stop after getting expected events
232228
break
233-
229+
234230
# Verify event types and order
235231
assert len(received_events) >= 3
236232
assert isinstance(received_events[0], BidiConnectionStartEvent)
237233
assert isinstance(received_events[1], BidiAudioStreamEvent)
238234
assert isinstance(received_events[2], BidiTranscriptStreamEvent)
239-
235+
240236
# Test empty events
241237
agent.model.set_events([])
242238
await agent.stop()
243239
await agent.start()
244-
240+
245241
empty_events = []
246242
async for event in agent.receive():
247243
empty_events.append(event)
248244
if len(empty_events) >= 2:
249245
break
250-
246+
251247
assert len(empty_events) >= 1
252248
assert isinstance(empty_events[0], BidiConnectionStartEvent)
253249

250+
254251
def test_bidi_agent_tool_integration(agent, mock_tool_registry):
255252
"""Test agent tool integration and properties."""
256253
# Test tool property access
257-
assert hasattr(agent, 'tool')
254+
assert hasattr(agent, "tool")
258255
assert agent.tool is not None
259256
assert agent.tool == agent._tool_caller
260-
257+
261258
# Test tool names property
262-
mock_tool_registry.get_all_tools_config.return_value = {
263-
"calculator": {},
264-
"weather": {}
265-
}
266-
259+
mock_tool_registry.get_all_tools_config.return_value = {"calculator": {}, "weather": {}}
260+
267261
tool_names = agent.tool_names
268262
assert isinstance(tool_names, list)
269263
assert len(tool_names) == 2
270264
assert "calculator" in tool_names
271265
assert "weather" in tool_names
272266

267+
273268
@pytest.mark.asyncio
274269
async def test_bidi_agent_send_receive_error_before_start(agent):
275270
"""Test error handling in various scenarios."""
276271
# Test send before start
277272
with pytest.raises(RuntimeError, match="call start before"):
278273
await agent.send(BidiTextInputEvent(text="Hello", role="user"))
279-
274+
280275
# Test receive before start
281276
with pytest.raises(RuntimeError, match="call start before"):
282-
async for event in agent.receive():
277+
async for _ in agent.receive():
283278
pass
284-
279+
285280
# Test send after stop
286281
await agent.start()
287282
await agent.stop()
288283
with pytest.raises(RuntimeError, match="call start before"):
289284
await agent.send(BidiTextInputEvent(text="Hello", role="user"))
290-
285+
291286
# Test receive after stop
292287
with pytest.raises(RuntimeError, match="call start before"):
293-
async for event in agent.receive():
288+
async for _ in agent.receive():
294289
pass
295290

296291

@@ -301,43 +296,44 @@ async def test_bidi_agent_start_receive_propagates_model_errors():
301296
mock_model = MockBidiModel()
302297
mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed"))
303298
error_agent = BidiAgent(model=mock_model)
304-
299+
305300
with pytest.raises(Exception, match="Connection failed"):
306301
await error_agent.start()
307-
302+
308303
# Test model receive error
309304
mock_model2 = MockBidiModel()
310305
agent2 = BidiAgent(model=mock_model2)
311306
await agent2.start()
312-
307+
313308
async def failing_receive():
314309
yield BidiConnectionStartEvent(connection_id="test", model="test-model")
315310
raise Exception("Receive failed")
316-
311+
317312
agent2.model.receive = failing_receive
318313
with pytest.raises(Exception, match="Receive failed"):
319-
async for event in agent2.receive():
314+
async for _ in agent2.receive():
320315
pass
321316

317+
322318
@pytest.mark.asyncio
323319
async def test_bidi_agent_state_consistency(agent):
324320
"""Test that agent state remains consistent across operations."""
325321
# Initial state
326322
assert not agent._started
327323
assert agent.model._connection_id is None
328-
324+
329325
# Start
330326
await agent.start()
331327
assert agent._started
332328
assert agent.model._connection_id is not None
333329
connection_id = agent.model._connection_id
334-
330+
335331
# Send operations shouldn't change connection state
336332
await agent.send(BidiTextInputEvent(text="Hello", role="user"))
337333
assert agent._started
338334
assert agent.model._connection_id == connection_id
339-
335+
340336
# Stop
341337
await agent.stop()
342338
assert not agent._started
343-
assert agent.model._connection_id is None
339+
assert agent.model._connection_id is None

0 commit comments

Comments
 (0)