Skip to content

Commit ada9942

Browse files
authored
Merge pull request #131 from stealthrocket/any-all-race
Concurrency primitives: any, all, race
2 parents e3e4b2a + a157ce1 commit ada9942

File tree

4 files changed

+532
-40
lines changed

4 files changed

+532
-40
lines changed

src/dispatch/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import dispatch.integrations
6-
from dispatch.coroutine import call, gather
6+
from dispatch.coroutine import all, any, call, gather, race
77
from dispatch.function import DEFAULT_API_URL, Client
88
from dispatch.id import DispatchID
99
from dispatch.proto import Call, Error, Input, Output
@@ -20,4 +20,7 @@
2020
"Status",
2121
"call",
2222
"gather",
23+
"all",
24+
"any",
25+
"race",
2326
]

src/dispatch/coroutine.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,60 @@ def call(call: Call) -> Any:
1717
@coroutine
1818
@durable
1919
def gather(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
20-
"""Concurrently run a set of coroutines and block until all
21-
results are available. If any coroutine fails with an uncaught
22-
exception, it will be re-raised when awaiting a result here."""
23-
return (yield Gather(awaitables))
20+
"""Alias for all."""
21+
return all(*awaitables)
22+
23+
24+
@coroutine
25+
@durable
26+
def all(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
27+
"""Concurrently run a set of coroutines, blocking until all coroutines
28+
return or any coroutine raises an error. If any coroutine fails with an
29+
uncaught exception, the exception will be re-raised here."""
30+
return (yield AllDirective(awaitables))
31+
32+
33+
@coroutine
34+
@durable
35+
def any(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
36+
"""Concurrently run a set of coroutines, blocking until any coroutine
37+
returns or all coroutines raises an error. If all coroutines fail with
38+
uncaught exceptions, the exception(s) will be re-raised here."""
39+
return (yield AnyDirective(awaitables))
40+
41+
42+
@coroutine
43+
@durable
44+
def race(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
45+
"""Concurrently run a set of coroutines, blocking until any coroutine
46+
returns or raises an error. If any coroutine fails with an uncaught
47+
exception, the exception will be re-raised here."""
48+
return (yield RaceDirective(awaitables))
2449

2550

2651
@dataclass(slots=True)
27-
class Gather:
52+
class AllDirective:
2853
awaitables: tuple[Awaitable[Any], ...]
54+
55+
56+
@dataclass(slots=True)
57+
class AnyDirective:
58+
awaitables: tuple[Awaitable[Any], ...]
59+
60+
61+
@dataclass(slots=True)
62+
class RaceDirective:
63+
awaitables: tuple[Awaitable[Any], ...]
64+
65+
66+
class AnyException(RuntimeError):
67+
"""Error indicating that all coroutines passed to any() failed
68+
with an exception."""
69+
70+
__slots__ = ("exceptions",)
71+
72+
def __init__(self, exceptions: list[Exception]):
73+
self.exceptions = exceptions
74+
75+
def __str__(self):
76+
return f"{len(self.exceptions)} coroutine(s) failed with an exception"

src/dispatch/scheduler.py

Lines changed: 159 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22
import pickle
33
import sys
4-
from dataclasses import dataclass
5-
from typing import Any, Callable, Protocol, TypeAlias
4+
from dataclasses import dataclass, field
5+
from typing import Any, Awaitable, Callable, Protocol, TypeAlias
66

7-
from dispatch.coroutine import Gather
7+
from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
88
from dispatch.error import IncompatibleStateError
99
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
1010
from dispatch.proto import Call, Error, Input, Output
@@ -73,17 +73,18 @@ def error(self) -> Exception | None:
7373
return self.first_error
7474

7575
def value(self) -> Any:
76+
assert self.first_error is None
7677
assert self.result is not None
7778
return self.result.value
7879

7980

8081
@dataclass(slots=True)
81-
class GatherFuture:
82-
"""A future result of a dispatch.coroutine.gather() operation."""
82+
class AllFuture:
83+
"""A future result of a dispatch.coroutine.all() operation."""
8384

84-
order: list[CoroutineID]
85-
waiting: set[CoroutineID]
86-
results: dict[CoroutineID, CoroutineResult]
85+
order: list[CoroutineID] = field(default_factory=list)
86+
waiting: set[CoroutineID] = field(default_factory=set)
87+
results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict)
8788
first_error: Exception | None = None
8889

8990
def add_result(self, result: CallResult | CoroutineResult):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
9495
except KeyError:
9596
return
9697

97-
if result.error is not None and self.first_error is None:
98-
self.first_error = result.error
98+
if result.error is not None:
99+
if self.first_error is None:
100+
self.first_error = result.error
101+
return
99102

100103
self.results[result.coroutine_id] = result
101104

102105
def add_error(self, error: Exception):
103-
if self.first_error is not None:
106+
if self.first_error is None:
104107
self.first_error = error
105108

106109
def ready(self) -> bool:
@@ -113,9 +116,108 @@ def error(self) -> Exception | None:
113116
def value(self) -> list[Any]:
114117
assert self.ready()
115118
assert len(self.waiting) == 0
119+
assert self.first_error is None
116120
return [self.results[id].value for id in self.order]
117121

118122

123+
@dataclass(slots=True)
124+
class AnyFuture:
125+
"""A future result of a dispatch.coroutine.any() operation."""
126+
127+
order: list[CoroutineID] = field(default_factory=list)
128+
waiting: set[CoroutineID] = field(default_factory=set)
129+
first_result: CoroutineResult | None = None
130+
errors: dict[CoroutineID, Exception] = field(default_factory=dict)
131+
generic_error: Exception | None = None
132+
133+
def add_result(self, result: CallResult | CoroutineResult):
134+
assert isinstance(result, CoroutineResult)
135+
136+
try:
137+
self.waiting.remove(result.coroutine_id)
138+
except KeyError:
139+
return
140+
141+
if result.error is None:
142+
if self.first_result is None:
143+
self.first_result = result
144+
return
145+
146+
self.errors[result.coroutine_id] = result.error
147+
148+
def add_error(self, error: Exception):
149+
if self.generic_error is None:
150+
self.generic_error = error
151+
152+
def ready(self) -> bool:
153+
return (
154+
self.generic_error is not None
155+
or self.first_result is not None
156+
or len(self.waiting) == 0
157+
)
158+
159+
def error(self) -> Exception | None:
160+
assert self.ready()
161+
if self.generic_error is not None:
162+
return self.generic_error
163+
if self.first_result is not None or len(self.errors) == 0:
164+
return None
165+
match len(self.errors):
166+
case 0:
167+
return None
168+
case 1:
169+
return self.errors[self.order[0]]
170+
case _:
171+
return AnyException([self.errors[id] for id in self.order])
172+
173+
def value(self) -> Any:
174+
assert self.ready()
175+
if len(self.order) == 0:
176+
return None
177+
assert self.first_result is not None
178+
return self.first_result.value
179+
180+
181+
@dataclass(slots=True)
182+
class RaceFuture:
183+
"""A future result of a dispatch.coroutine.race() operation."""
184+
185+
waiting: set[CoroutineID] = field(default_factory=set)
186+
first_result: CoroutineResult | None = None
187+
first_error: Exception | None = None
188+
189+
def add_result(self, result: CallResult | CoroutineResult):
190+
assert isinstance(result, CoroutineResult)
191+
192+
if result.error is not None:
193+
if self.first_error is None:
194+
self.first_error = result.error
195+
else:
196+
if self.first_result is None:
197+
self.first_result = result
198+
199+
self.waiting.remove(result.coroutine_id)
200+
201+
def add_error(self, error: Exception):
202+
if self.first_error is None:
203+
self.first_error = error
204+
205+
def ready(self) -> bool:
206+
return (
207+
self.first_error is not None
208+
or self.first_result is not None
209+
or len(self.waiting) == 0
210+
)
211+
212+
def error(self) -> Exception | None:
213+
assert self.ready()
214+
return self.first_error
215+
216+
def value(self) -> Any:
217+
assert self.first_error is None
218+
return self.first_result.value if self.first_result else None
219+
220+
119221
@dataclass(slots=True)
120222
class Coroutine:
121223
"""An in-flight coroutine."""
@@ -386,30 +488,35 @@ def _run(self, input: Input) -> Output:
386488
state.prev_callers.append(coroutine)
387489
state.outstanding_calls += 1
388490

389-
case Gather():
390-
gather = coroutine_yield
391-
392-
children = []
393-
for awaitable in gather.awaitables:
394-
g = awaitable.__await__()
395-
if not isinstance(g, DurableGenerator):
396-
raise ValueError(
397-
"gather awaitable is not a @dispatch.function"
398-
)
399-
child_id = state.next_coroutine_id
400-
state.next_coroutine_id += 1
401-
child = Coroutine(
402-
id=child_id, parent_id=coroutine.id, coroutine=g
403-
)
404-
logger.debug("enqueuing %s for %s", child, coroutine)
405-
children.append(child)
491+
case AllDirective():
492+
children = spawn_children(
493+
state, coroutine, coroutine_yield.awaitables
494+
)
406495

407-
# Prepend children to get a depth-first traversal of coroutines.
408-
state.ready = children + state.ready
496+
child_ids = [child.id for child in children]
497+
coroutine.result = AllFuture(
498+
order=child_ids, waiting=set(child_ids)
499+
)
500+
state.suspended[coroutine.id] = coroutine
501+
502+
case AnyDirective():
503+
children = spawn_children(
504+
state, coroutine, coroutine_yield.awaitables
505+
)
409506

410507
child_ids = [child.id for child in children]
411-
coroutine.result = GatherFuture(
412-
order=child_ids, waiting=set(child_ids), results={}
508+
coroutine.result = AnyFuture(
509+
order=child_ids, waiting=set(child_ids)
510+
)
511+
state.suspended[coroutine.id] = coroutine
512+
513+
case RaceDirective():
514+
children = spawn_children(
515+
state, coroutine, coroutine_yield.awaitables
516+
)
517+
518+
coroutine.result = RaceFuture(
519+
waiting={child.id for child in children}
413520
)
414521
state.suspended[coroutine.id] = coroutine
415522

@@ -446,6 +553,26 @@ def _run(self, input: Input) -> Output:
446553
)
447554

448555

556+
def spawn_children(
557+
state: State, coroutine: Coroutine, awaitables: tuple[Awaitable[Any], ...]
558+
) -> list[Coroutine]:
559+
children = []
560+
for awaitable in awaitables:
561+
g = awaitable.__await__()
562+
if not isinstance(g, DurableGenerator):
563+
raise TypeError("awaitable is not a @dispatch.function")
564+
child_id = state.next_coroutine_id
565+
state.next_coroutine_id += 1
566+
child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g)
567+
logger.debug("enqueuing %s for %s", child, coroutine)
568+
children.append(child)
569+
570+
# Prepend children to get a depth-first traversal of coroutines.
571+
state.ready = children + state.ready
572+
573+
return children
574+
575+
449576
def correlation_id(coroutine_id: CoroutineID, call_id: CallID) -> CorrelationID:
450577
return coroutine_id << 32 | call_id
451578

0 commit comments

Comments
 (0)