Skip to content

Commit a157ce1

Browse files
committed
Add race()
1 parent 0ad0932 commit a157ce1

File tree

4 files changed

+194
-6
lines changed

4 files changed

+194
-6
lines changed

src/dispatch/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import dispatch.integrations
6-
from dispatch.coroutine import all, call, gather
6+
from dispatch.coroutine import all, any, call, gather, race
77
from dispatch.function import DEFAULT_API_URL, Client
88
from dispatch.id import DispatchID
99
from dispatch.proto import Call, Error, Input, Output
@@ -21,4 +21,6 @@
2121
"call",
2222
"gather",
2323
"all",
24+
"any",
25+
"race",
2426
]

src/dispatch/coroutine.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def gather(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
2525
@durable
2626
def all(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
2727
"""Concurrently run a set of coroutines, blocking until all coroutines
28-
return or any coroutine raises an error. If a coroutine fails with an
28+
return or any coroutine raises an error. If any coroutine fails with an
2929
uncaught exception, the exception will be re-raised here."""
3030
return (yield AllDirective(awaitables))
3131

@@ -35,10 +35,19 @@ def all(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
3535
def any(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
3636
"""Concurrently run a set of coroutines, blocking until any coroutine
3737
returns or all coroutines raises an error. If all coroutines fail with
38-
uncaught exceptions, an AnyException will be re-raised here."""
38+
uncaught exceptions, the exception(s) will be re-raised here."""
3939
return (yield AnyDirective(awaitables))
4040

4141

42+
@coroutine
43+
@durable
44+
def race(*awaitables: Awaitable[Any]) -> list[Any]: # type: ignore[misc]
45+
"""Concurrently run a set of coroutines, blocking until any coroutine
46+
returns or raises an error. If any coroutine fails with an uncaught
47+
exception, the exception will be re-raised here."""
48+
return (yield RaceDirective(awaitables))
49+
50+
4251
@dataclass(slots=True)
4352
class AllDirective:
4453
awaitables: tuple[Awaitable[Any], ...]
@@ -49,6 +58,11 @@ class AnyDirective:
4958
awaitables: tuple[Awaitable[Any], ...]
5059

5160

61+
@dataclass(slots=True)
62+
class RaceDirective:
63+
awaitables: tuple[Awaitable[Any], ...]
64+
65+
5266
class AnyException(RuntimeError):
5367
"""Error indicating that all coroutines passed to any() failed
5468
with an exception."""

src/dispatch/scheduler.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass, field
55
from typing import Any, Awaitable, Callable, Protocol, TypeAlias
66

7-
from dispatch.coroutine import AllDirective, AnyDirective, AnyException
7+
from dispatch.coroutine import AllDirective, AnyDirective, AnyException, RaceDirective
88
from dispatch.error import IncompatibleStateError
99
from dispatch.experimental.durable.function import DurableCoroutine, DurableGenerator
1010
from dispatch.proto import Call, Error, Input, Output
@@ -178,6 +178,46 @@ def value(self) -> Any:
178178
return self.first_result.value
179179

180180

181+
@dataclass(slots=True)
182+
class RaceFuture:
183+
"""A future result of a dispatch.coroutine.race() operation."""
184+
185+
waiting: set[CoroutineID] = field(default_factory=set)
186+
first_result: CoroutineResult | None = None
187+
first_error: Exception | None = None
188+
189+
def add_result(self, result: CallResult | CoroutineResult):
190+
assert isinstance(result, CoroutineResult)
191+
192+
if result.error is not None:
193+
if self.first_error is None:
194+
self.first_error = result.error
195+
else:
196+
if self.first_result is None:
197+
self.first_result = result
198+
199+
self.waiting.remove(result.coroutine_id)
200+
201+
def add_error(self, error: Exception):
202+
if self.first_error is None:
203+
self.first_error = error
204+
205+
def ready(self) -> bool:
206+
return (
207+
self.first_error is not None
208+
or self.first_result is not None
209+
or len(self.waiting) == 0
210+
)
211+
212+
def error(self) -> Exception | None:
213+
assert self.ready()
214+
return self.first_error
215+
216+
def value(self) -> Any:
217+
assert self.first_error is None
218+
return self.first_result.value if self.first_result else None
219+
220+
181221
@dataclass(slots=True)
182222
class Coroutine:
183223
"""An in-flight coroutine."""
@@ -470,6 +510,16 @@ def _run(self, input: Input) -> Output:
470510
)
471511
state.suspended[coroutine.id] = coroutine
472512

513+
case RaceDirective():
514+
children = spawn_children(
515+
state, coroutine, coroutine_yield.awaitables
516+
)
517+
518+
coroutine.result = RaceFuture(
519+
waiting={child.id for child in children}
520+
)
521+
state.suspended[coroutine.id] = coroutine
522+
473523
case _:
474524
raise RuntimeError(
475525
f"coroutine unexpectedly yielded '{coroutine_yield}'"

tests/dispatch/test_scheduler.py

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import unittest
22
from typing import Any, Callable
33

4-
from dispatch.coroutine import AnyException, any, call, gather
4+
from dispatch.coroutine import AnyException, any, call, gather, race
55
from dispatch.experimental.durable import durable
66
from dispatch.proto import Call, CallResult, Error, Input, Output
77
from dispatch.proto import _any_unpickle as any_unpickle
8-
from dispatch.scheduler import AllFuture, AnyFuture, CoroutineResult, OneShotScheduler
8+
from dispatch.scheduler import (
9+
AllFuture,
10+
AnyFuture,
11+
CoroutineResult,
12+
OneShotScheduler,
13+
RaceFuture,
14+
)
915
from dispatch.sdk.v1 import call_pb2 as call_pb
1016
from dispatch.sdk.v1 import exit_pb2 as exit_pb
1117
from dispatch.sdk.v1 import poll_pb2 as poll_pb
@@ -21,6 +27,11 @@ async def call_any(*functions):
2127
return await any(*[call_one(function) for function in functions])
2228

2329

30+
@durable
31+
async def call_race(*functions):
32+
return await race(*[call_one(function) for function in functions])
33+
34+
2435
@durable
2536
async def call_concurrently(*functions):
2637
return await gather(*[call_one(function) for function in functions])
@@ -201,6 +212,37 @@ async def main():
201212
output, AnyException, "4 coroutine(s) failed with an exception"
202213
)
203214

215+
def test_resume_after_race_result(self):
216+
@durable
217+
async def main():
218+
return await call_race("a", "b", "c", "d")
219+
220+
output = self.start(main)
221+
calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
222+
223+
output = self.resume(
224+
main,
225+
output,
226+
[CallResult.from_value(23, correlation_id=calls[1].correlation_id)],
227+
)
228+
self.assert_exit_result_value(output, 23)
229+
230+
def test_resume_after_race_error(self):
231+
@durable
232+
async def main():
233+
return await call_race("a", "b", "c", "d")
234+
235+
output = self.start(main)
236+
calls = self.assert_poll_call_functions(output, ["a", "b", "c", "d"])
237+
238+
error = Error.from_exception(RuntimeError("oops"))
239+
output = self.resume(
240+
main,
241+
output,
242+
[CallResult.from_error(error, correlation_id=calls[2].correlation_id)],
243+
)
244+
self.assert_exit_result_error(output, RuntimeError, "oops")
245+
204246
def test_dag(self):
205247
@durable
206248
async def main():
@@ -600,3 +642,83 @@ def test_two_result_errors(self):
600642

601643
with self.assertRaises(AssertionError):
602644
future.value()
645+
646+
647+
class TestRaceFuture(unittest.TestCase):
648+
def test_empty(self):
649+
future = RaceFuture()
650+
651+
self.assertTrue(future.ready())
652+
self.assertIsNone(future.value())
653+
self.assertIsNone(future.error())
654+
655+
def test_one_result_value(self):
656+
future = RaceFuture(waiting={10})
657+
658+
self.assertFalse(future.ready())
659+
future.add_result(CoroutineResult(coroutine_id=10, value="foobar"))
660+
661+
self.assertTrue(future.ready())
662+
self.assertIsNone(future.error())
663+
self.assertEqual(future.value(), "foobar")
664+
665+
def test_one_result_error(self):
666+
future = RaceFuture(waiting={10})
667+
668+
self.assertFalse(future.ready())
669+
error = RuntimeError("oops")
670+
future.add_result(CoroutineResult(coroutine_id=10, error=error))
671+
672+
self.assertTrue(future.ready())
673+
self.assertIs(future.error(), error)
674+
675+
with self.assertRaises(AssertionError):
676+
future.value()
677+
678+
def test_one_generic_error(self):
679+
future = RaceFuture(waiting={10})
680+
681+
self.assertFalse(future.ready())
682+
error = RuntimeError("oops")
683+
future.add_error(error)
684+
685+
self.assertTrue(future.ready())
686+
self.assertIs(future.error(), error)
687+
688+
with self.assertRaises(AssertionError):
689+
future.value()
690+
691+
def test_two_result_values(self):
692+
future = RaceFuture(waiting={10, 20})
693+
694+
self.assertFalse(future.ready())
695+
696+
future.add_result(CoroutineResult(coroutine_id=20, value="bar"))
697+
self.assertTrue(future.ready())
698+
self.assertIsNone(future.error())
699+
self.assertEqual(future.value(), "bar")
700+
701+
future.add_result(CoroutineResult(coroutine_id=10, value="foo"))
702+
self.assertTrue(future.ready())
703+
self.assertIsNone(future.error())
704+
self.assertEqual(future.value(), "bar")
705+
706+
def test_two_result_errors(self):
707+
future = RaceFuture(waiting={10, 20})
708+
709+
self.assertFalse(future.ready())
710+
error1 = RuntimeError("oops")
711+
future.add_result(CoroutineResult(coroutine_id=10, error=error1))
712+
713+
self.assertTrue(future.ready())
714+
self.assertIs(future.error(), error1)
715+
716+
error2 = RuntimeError("oops2")
717+
future.add_result(CoroutineResult(coroutine_id=20, error=error2))
718+
self.assertIs(future.error(), error1)
719+
720+
future.add_error(error2)
721+
self.assertIs(future.error(), error1)
722+
723+
with self.assertRaises(AssertionError):
724+
future.value()

0 commit comments

Comments
 (0)