1
1
import unittest
2
2
from typing import Any , Callable
3
3
4
- from dispatch .coroutine import AnyException , any , call , gather
4
+ from dispatch .coroutine import AnyException , any , call , gather , race
5
5
from dispatch .experimental .durable import durable
6
6
from dispatch .proto import Call , CallResult , Error , Input , Output
7
7
from dispatch .proto import _any_unpickle as any_unpickle
8
- from dispatch .scheduler import AllFuture , AnyFuture , CoroutineResult , OneShotScheduler
8
+ from dispatch .scheduler import (
9
+ AllFuture ,
10
+ AnyFuture ,
11
+ CoroutineResult ,
12
+ OneShotScheduler ,
13
+ RaceFuture ,
14
+ )
9
15
from dispatch .sdk .v1 import call_pb2 as call_pb
10
16
from dispatch .sdk .v1 import exit_pb2 as exit_pb
11
17
from dispatch .sdk .v1 import poll_pb2 as poll_pb
@@ -21,6 +27,11 @@ async def call_any(*functions):
21
27
return await any (* [call_one (function ) for function in functions ])
22
28
23
29
30
+ @durable
31
+ async def call_race (* functions ):
32
+ return await race (* [call_one (function ) for function in functions ])
33
+
34
+
24
35
@durable
25
36
async def call_concurrently (* functions ):
26
37
return await gather (* [call_one (function ) for function in functions ])
@@ -201,6 +212,37 @@ async def main():
201
212
output , AnyException , "4 coroutine(s) failed with an exception"
202
213
)
203
214
215
+ def test_resume_after_race_result (self ):
216
+ @durable
217
+ async def main ():
218
+ return await call_race ("a" , "b" , "c" , "d" )
219
+
220
+ output = self .start (main )
221
+ calls = self .assert_poll_call_functions (output , ["a" , "b" , "c" , "d" ])
222
+
223
+ output = self .resume (
224
+ main ,
225
+ output ,
226
+ [CallResult .from_value (23 , correlation_id = calls [1 ].correlation_id )],
227
+ )
228
+ self .assert_exit_result_value (output , 23 )
229
+
230
+ def test_resume_after_race_error (self ):
231
+ @durable
232
+ async def main ():
233
+ return await call_race ("a" , "b" , "c" , "d" )
234
+
235
+ output = self .start (main )
236
+ calls = self .assert_poll_call_functions (output , ["a" , "b" , "c" , "d" ])
237
+
238
+ error = Error .from_exception (RuntimeError ("oops" ))
239
+ output = self .resume (
240
+ main ,
241
+ output ,
242
+ [CallResult .from_error (error , correlation_id = calls [2 ].correlation_id )],
243
+ )
244
+ self .assert_exit_result_error (output , RuntimeError , "oops" )
245
+
204
246
def test_dag (self ):
205
247
@durable
206
248
async def main ():
@@ -600,3 +642,83 @@ def test_two_result_errors(self):
600
642
601
643
with self .assertRaises (AssertionError ):
602
644
future .value ()
645
+
646
+
647
+ class TestRaceFuture (unittest .TestCase ):
648
+ def test_empty (self ):
649
+ future = RaceFuture ()
650
+
651
+ self .assertTrue (future .ready ())
652
+ self .assertIsNone (future .value ())
653
+ self .assertIsNone (future .error ())
654
+
655
+ def test_one_result_value (self ):
656
+ future = RaceFuture (waiting = {10 })
657
+
658
+ self .assertFalse (future .ready ())
659
+ future .add_result (CoroutineResult (coroutine_id = 10 , value = "foobar" ))
660
+
661
+ self .assertTrue (future .ready ())
662
+ self .assertIsNone (future .error ())
663
+ self .assertEqual (future .value (), "foobar" )
664
+
665
+ def test_one_result_error (self ):
666
+ future = RaceFuture (waiting = {10 })
667
+
668
+ self .assertFalse (future .ready ())
669
+ error = RuntimeError ("oops" )
670
+ future .add_result (CoroutineResult (coroutine_id = 10 , error = error ))
671
+
672
+ self .assertTrue (future .ready ())
673
+ self .assertIs (future .error (), error )
674
+
675
+ with self .assertRaises (AssertionError ):
676
+ future .value ()
677
+
678
+ def test_one_generic_error (self ):
679
+ future = RaceFuture (waiting = {10 })
680
+
681
+ self .assertFalse (future .ready ())
682
+ error = RuntimeError ("oops" )
683
+ future .add_error (error )
684
+
685
+ self .assertTrue (future .ready ())
686
+ self .assertIs (future .error (), error )
687
+
688
+ with self .assertRaises (AssertionError ):
689
+ future .value ()
690
+
691
+ def test_two_result_values (self ):
692
+ future = RaceFuture (waiting = {10 , 20 })
693
+
694
+ self .assertFalse (future .ready ())
695
+
696
+ future .add_result (CoroutineResult (coroutine_id = 20 , value = "bar" ))
697
+ self .assertTrue (future .ready ())
698
+ self .assertIsNone (future .error ())
699
+ self .assertEqual (future .value (), "bar" )
700
+
701
+ future .add_result (CoroutineResult (coroutine_id = 10 , value = "foo" ))
702
+ self .assertTrue (future .ready ())
703
+ self .assertIsNone (future .error ())
704
+ self .assertEqual (future .value (), "bar" )
705
+
706
+ def test_two_result_errors (self ):
707
+ future = RaceFuture (waiting = {10 , 20 })
708
+
709
+ self .assertFalse (future .ready ())
710
+ error1 = RuntimeError ("oops" )
711
+ future .add_result (CoroutineResult (coroutine_id = 10 , error = error1 ))
712
+
713
+ self .assertTrue (future .ready ())
714
+ self .assertIs (future .error (), error1 )
715
+
716
+ error2 = RuntimeError ("oops2" )
717
+ future .add_result (CoroutineResult (coroutine_id = 20 , error = error2 ))
718
+ self .assertIs (future .error (), error1 )
719
+
720
+ future .add_error (error2 )
721
+ self .assertIs (future .error (), error1 )
722
+
723
+ with self .assertRaises (AssertionError ):
724
+ future .value ()
0 commit comments