Skip to content

Commit 75cc8af

Browse files
committed
task group
1 parent db934ac commit 75cc8af

File tree

4 files changed

+127
-19
lines changed

4 files changed

+127
-19
lines changed

src/strands/experimental/bidi/_async/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from typing import Awaitable, Callable
44

5+
from ._task_group import _TaskGroup
56
from ._task_pool import _TaskPool
67

7-
__all__ = ["_TaskPool"]
8+
__all__ = ["_TaskGroup", "_TaskPool"]
89

910

1011
async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
@@ -28,6 +29,6 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
2829
if exceptions:
2930
exceptions.append(RuntimeError("failed stop sequence"))
3031
for i in range(1, len(exceptions)):
31-
exceptions[i].__cause__ = exceptions[i - 1]
32+
exceptions[i].__context__ = exceptions[i - 1]
3233

3334
raise exceptions[-1]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Manage a group of async tasks.
2+
3+
This is intended to mimic the behaviors of asyncio.TaskGroup released in Python 3.11.
4+
5+
- Docs: https://docs.python.org/3/library/asyncio-task.html#task-groups
6+
"""
7+
8+
import asyncio
9+
from typing import Any, Coroutine
10+
11+
12+
class _TaskGroup:
13+
"""Implementation of asyncio.TaskGroup for use in Python 3.10.
14+
15+
Attributes:
16+
_tasks: List of tasks in group.
17+
"""
18+
19+
_tasks: list[asyncio.Task]
20+
21+
def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
22+
"""Create an async task and add to group.
23+
24+
Returns:
25+
The created task.
26+
"""
27+
task = asyncio.create_task(coro)
28+
self._tasks.append(task)
29+
return task
30+
31+
async def __aenter__(self) -> "_TaskGroup":
32+
"""Setup self managed task group context."""
33+
self._tasks = []
34+
return self
35+
36+
async def __aexit__(self, *_: Any) -> None:
37+
"""Execute tasks in group.
38+
39+
The following execution rules are enforced:
40+
- The context stops executing all tasks if at least one task raises an Exception or the context is cancelled.
41+
- The context re-raises Exceptions to the caller.
42+
- The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
43+
"""
44+
try:
45+
await asyncio.gather(*self._tasks)
46+
47+
except (Exception, asyncio.CancelledError) as error:
48+
for task in self._tasks:
49+
task.cancel()
50+
51+
await asyncio.gather(*self._tasks, return_exceptions=True)
52+
53+
if not isinstance(error, asyncio.CancelledError):
54+
raise
55+
56+
context_task = asyncio.current_task()
57+
if context_task and context_task.cancelling() > 0: # context itself was cancelled
58+
raise
59+
60+
finally:
61+
self._tasks = []

src/strands/experimental/bidi/agent/agent.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ....types.tools import AgentTool
3131
from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent
3232
from ...tools import ToolProvider
33-
from .._async import stop_all
33+
from .._async import _TaskGroup, stop_all
3434
from ..models.model import BidiModel
3535
from ..models.nova_sonic import BidiNovaSonicModel
3636
from ..types.agent import BidiAgentInput
@@ -390,22 +390,9 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
390390
for start in [*input_starts, *output_starts]:
391391
await start(self)
392392

393-
inputs_task = asyncio.create_task(run_inputs())
394-
outputs_task = asyncio.create_task(run_outputs(inputs_task))
395-
396-
try:
397-
await asyncio.gather(inputs_task, outputs_task)
398-
except (Exception, asyncio.CancelledError) as error:
399-
inputs_task.cancel()
400-
outputs_task.cancel()
401-
await asyncio.gather(inputs_task, outputs_task, return_exceptions=True)
402-
403-
if not isinstance(error, asyncio.CancelledError):
404-
raise
405-
406-
run_task = asyncio.current_task()
407-
if run_task and run_task.cancelling() > 0: # externally cancelled
408-
raise
393+
async with _TaskGroup() as task_group:
394+
inputs_task = task_group.create_task(run_inputs())
395+
task_group.create_task(run_outputs(inputs_task))
409396

410397
finally:
411398
input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import asyncio
2+
import unittest.mock
3+
4+
import pytest
5+
6+
from strands.experimental.bidi._async._task_group import _TaskGroup
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_task_group__aexit__():
11+
coro = unittest.mock.AsyncMock()
12+
13+
async with _TaskGroup() as task_group:
14+
task_group.create_task(coro())
15+
16+
coro.assert_called_once()
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_task_group__aexit__exception():
21+
wait_event = asyncio.Event()
22+
async def wait():
23+
await wait_event.wait()
24+
25+
async def fail():
26+
raise ValueError("test error")
27+
28+
with pytest.raises(ValueError, match="test error"):
29+
async with _TaskGroup() as task_group:
30+
wait_task = task_group.create_task(wait())
31+
fail_task = task_group.create_task(fail())
32+
33+
assert wait_task.cancelled()
34+
assert not fail_task.cancelled()
35+
36+
37+
@pytest.mark.asyncio
38+
async def test_task_group__aexit__cancelled():
39+
wait_event = asyncio.Event()
40+
async def wait():
41+
await wait_event.wait()
42+
43+
tasks = []
44+
45+
run_event = asyncio.Event()
46+
async def run():
47+
async with _TaskGroup() as task_group:
48+
tasks.append(task_group.create_task(wait()))
49+
run_event.set()
50+
51+
run_task = asyncio.create_task(run())
52+
await run_event.wait()
53+
run_task.cancel()
54+
55+
with pytest.raises(asyncio.CancelledError):
56+
await run_task
57+
58+
wait_task = tasks[0]
59+
assert wait_task.cancelled()

0 commit comments

Comments
 (0)