Skip to content

Commit 0ad0932

Browse files
committed
Add any()
1 parent 49f605b commit 0ad0932

File tree

3 files changed

+333
-35
lines changed

3 files changed

+333
-35
lines changed

src/dispatch/coroutine.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,39 @@ def gather(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
2424
@coroutine
2525
@durable
2626
def all(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
27-
"""Concurrently run a set of coroutines and block until all
28-
results are available. If any coroutine fails with an uncaught
29-
exception, it will be re-raised when awaiting a result here."""
30-
return (yield All(awaitables))
27+
"""Concurrently run a set of coroutines, blocking until all coroutines
28+
return or any coroutine raises an error. If a coroutine fails with an
29+
uncaught exception, the exception will be re-raised here."""
30+
return (yield AllDirective(awaitables))
31+
32+
33+
@coroutine
34+
@durable
35+
def any(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
36+
"""Concurrently run a set of coroutines, blocking until any coroutine
37+
returns or all coroutines raises an error. If all coroutines fail with
38+
uncaught exceptions, an AnyException will be re-raised here."""
39+
return (yield AnyDirective(awaitables))
3140

3241

3342
@dataclass(slots=True)
34-
class All:
43+
class AllDirective:
3544
awaitables: tuple[Awaitable[Any], ...]
45+
46+
47+
@dataclass(slots=True)
48+
class AnyDirective:
49+
awaitables: tuple[Awaitable[Any], ...]
50+
51+
52+
class AnyException(RuntimeError):
53+
"""Error indicating that all coroutines passed to any() failed
54+
with an exception."""
55+
56+
__slots__ = ("exceptions",)
57+
58+
def __init__(self, exceptions: list[Exception]):
59+
self.exceptions = exceptions
60+
61+
def __str__(self):
62+
return f"{len(self.exceptions)} coroutine(s) failed with an exception"

src/dispatch/scheduler.py

Lines changed: 107 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22
import pickle
33
import sys
4-
from dataclasses import dataclass
5-
from typing import Any, Callable, Protocol, TypeAlias
4+
from dataclasses import dataclass, field
5+
from typing import Any, Awaitable, Callable, Protocol, TypeAlias
66

7-
from dispatch.coroutine import All
7+
from dispatch.coroutine import AllDirective, AnyDirective, AnyException
88
from dispatch.error import IncompatibleStateError
99
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
1010
from dispatch.proto import Call, Error, Input, Output
@@ -73,6 +73,7 @@ def error(self) -> Exception | None:
7373
return self.first_error
7474

7575
def value(self) -> Any:
76+
assert self.first_error is None
7677
assert self.result is not None
7778
return self.result.value
7879

@@ -81,9 +82,9 @@ def value(self) -> Any:
8182
class AllFuture:
8283
"""A future result of a dispatch.coroutine.all() operation."""
8384

84-
order: list[CoroutineID]
85-
waiting: set[CoroutineID]
86-
results: dict[CoroutineID, CoroutineResult]
85+
order: list[CoroutineID] = field(default_factory=list)
86+
waiting: set[CoroutineID] = field(default_factory=set)
87+
results: dict[CoroutineID, CoroutineResult] = field(default_factory=dict)
8788
first_error: Exception | None = None
8889

8990
def add_result(self, result: CallResult | CoroutineResult):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
9495
except KeyError:
9596
return
9697

97-
if result.error is not None and self.first_error is None:
98-
self.first_error = result.error
98+
if result.error is not None:
99+
if self.first_error is None:
100+
self.first_error = result.error
101+
return
99102

100103
self.results[result.coroutine_id] = result
101104

102105
def add_error(self, error: Exception):
103-
if self.first_error is not None:
106+
if self.first_error is None:
104107
self.first_error = error
105108

106109
def ready(self) -> bool:
@@ -113,9 +116,68 @@ def error(self) -> Exception | None:
113116
def value(self) -> list[Any]:
114117
assert self.ready()
115118
assert len(self.waiting) == 0
119+
assert self.first_error is None
116120
return [self.results[id].value for id in self.order]
117121

118122

123+
@dataclass(slots=True)
124+
class AnyFuture:
125+
"""A future result of a dispatch.coroutine.any() operation."""
126+
127+
order: list[CoroutineID] = field(default_factory=list)
128+
waiting: set[CoroutineID] = field(default_factory=set)
129+
first_result: CoroutineResult | None = None
130+
errors: dict[CoroutineID, Exception] = field(default_factory=dict)
131+
generic_error: Exception | None = None
132+
133+
def add_result(self, result: CallResult | CoroutineResult):
134+
assert isinstance(result, CoroutineResult)
135+
136+
try:
137+
self.waiting.remove(result.coroutine_id)
138+
except KeyError:
139+
return
140+
141+
if result.error is None:
142+
if self.first_result is None:
143+
self.first_result = result
144+
return
145+
146+
self.errors[result.coroutine_id] = result.error
147+
148+
def add_error(self, error: Exception):
149+
if self.generic_error is None:
150+
self.generic_error = error
151+
152+
def ready(self) -> bool:
153+
return (
154+
self.generic_error is not None
155+
or self.first_result is not None
156+
or len(self.waiting) == 0
157+
)
158+
159+
def error(self) -> Exception | None:
160+
assert self.ready()
161+
if self.generic_error is not None:
162+
return self.generic_error
163+
if self.first_result is not None or len(self.errors) == 0:
164+
return None
165+
match len(self.errors):
166+
case 0:
167+
return None
168+
case 1:
169+
return self.errors[self.order[0]]
170+
case _:
171+
return AnyException([self.errors[id] for id in self.order])
172+
173+
def value(self) -> Any:
174+
assert self.ready()
175+
if len(self.order) == 0:
176+
return None
177+
assert self.first_result is not None
178+
return self.first_result.value
179+
180+
119181
@dataclass(slots=True)
120182
class Coroutine:
121183
"""An in-flight coroutine."""
@@ -386,28 +448,25 @@ def _run(self, input: Input) -> Output:
386448
state.prev_callers.append(coroutine)
387449
state.outstanding_calls += 1
388450

389-
case All():
390-
children = []
391-
for awaitable in coroutine_yield.awaitables:
392-
g = awaitable.__await__()
393-
if not isinstance(g, DurableGenerator):
394-
raise ValueError(
395-
"gather awaitable is not a @dispatch.function"
396-
)
397-
child_id = state.next_coroutine_id
398-
state.next_coroutine_id += 1
399-
child = Coroutine(
400-
id=child_id, parent_id=coroutine.id, coroutine=g
401-
)
402-
logger.debug("enqueuing %s for %s", child, coroutine)
403-
children.append(child)
404-
405-
# Prepend children to get a depth-first traversal of coroutines.
406-
state.ready = children + state.ready
451+
case AllDirective():
452+
children = spawn_children(
453+
state, coroutine, coroutine_yield.awaitables
454+
)
407455

408456
child_ids = [child.id for child in children]
409457
coroutine.result = AllFuture(
410-
order=child_ids, waiting=set(child_ids), results={}
458+
order=child_ids, waiting=set(child_ids)
459+
)
460+
state.suspended[coroutine.id] = coroutine
461+
462+
case AnyDirective():
463+
children = spawn_children(
464+
state, coroutine, coroutine_yield.awaitables
465+
)
466+
467+
child_ids = [child.id for child in children]
468+
coroutine.result = AnyFuture(
469+
order=child_ids, waiting=set(child_ids)
411470
)
412471
state.suspended[coroutine.id] = coroutine
413472

@@ -444,6 +503,26 @@ def _run(self, input: Input) -> Output:
444503
)
445504

446505

506+
def spawn_children(
507+
state: State, coroutine: Coroutine, awaitables: tuple[Awaitable[Any], ...]
508+
) -> list[Coroutine]:
509+
children = []
510+
for awaitable in awaitables:
511+
g = awaitable.__await__()
512+
if not isinstance(g, DurableGenerator):
513+
raise TypeError("awaitable is not a @dispatch.function")
514+
child_id = state.next_coroutine_id
515+
state.next_coroutine_id += 1
516+
child = Coroutine(id=child_id, parent_id=coroutine.id, coroutine=g)
517+
logger.debug("enqueuing %s for %s", child, coroutine)
518+
children.append(child)
519+
520+
# Prepend children to get a depth-first traversal of coroutines.
521+
state.ready = children + state.ready
522+
523+
return children
524+
525+
447526
def correlation_id(coroutine_id: CoroutineID, call_id: CallID) -> CorrelationID:
448527
return coroutine_id << 32 | call_id
449528

0 commit comments

Comments
 (0)