1
1
import logging
2
2
import pickle
3
3
import sys
4
- from dataclasses import dataclass
5
- from typing import Any , Callable , Protocol , TypeAlias
4
+ from dataclasses import dataclass , field
5
+ from typing import Any , Awaitable , Callable , Protocol , TypeAlias
6
6
7
- from dispatch .coroutine import Gather
7
+ from dispatch .coroutine import AllDirective , AnyDirective , AnyException , RaceDirective
8
8
from dispatch .error import IncompatibleStateError
9
9
from dispatch .experimental .durable .function import DurableCoroutine , DurableGenerator
10
10
from dispatch .proto import Call , Error , Input , Output
@@ -73,17 +73,18 @@ def error(self) -> Exception | None:
73
73
return self .first_error
74
74
75
75
def value (self ) -> Any :
76
+ assert self .first_error is None
76
77
assert self .result is not None
77
78
return self .result .value
78
79
79
80
80
81
@dataclass (slots = True )
81
- class GatherFuture :
82
- """A future result of a dispatch.coroutine.gather () operation."""
82
+ class AllFuture :
83
+ """A future result of a dispatch.coroutine.all () operation."""
83
84
84
- order : list [CoroutineID ]
85
- waiting : set [CoroutineID ]
86
- results : dict [CoroutineID , CoroutineResult ]
85
+ order : list [CoroutineID ] = field ( default_factory = list )
86
+ waiting : set [CoroutineID ] = field ( default_factory = set )
87
+ results : dict [CoroutineID , CoroutineResult ] = field ( default_factory = dict )
87
88
first_error : Exception | None = None
88
89
89
90
def add_result (self , result : CallResult | CoroutineResult ):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
94
95
except KeyError :
95
96
return
96
97
97
- if result .error is not None and self .first_error is None :
98
- self .first_error = result .error
98
+ if result .error is not None :
99
+ if self .first_error is None :
100
+ self .first_error = result .error
101
+ return
99
102
100
103
self .results [result .coroutine_id ] = result
101
104
102
105
def add_error (self , error : Exception ):
103
- if self .first_error is not None :
106
+ if self .first_error is None :
104
107
self .first_error = error
105
108
106
109
def ready (self ) -> bool :
@@ -113,9 +116,108 @@ def error(self) -> Exception | None:
113
116
def value (self ) -> list [Any ]:
114
117
assert self .ready ()
115
118
assert len (self .waiting ) == 0
119
+ assert self .first_error is None
116
120
return [self .results [id ].value for id in self .order ]
117
121
118
122
123
+ @dataclass (slots = True )
124
+ class AnyFuture :
125
+ """A future result of a dispatch.coroutine.any() operation."""
126
+
127
+ order : list [CoroutineID ] = field (default_factory = list )
128
+ waiting : set [CoroutineID ] = field (default_factory = set )
129
+ first_result : CoroutineResult | None = None
130
+ errors : dict [CoroutineID , Exception ] = field (default_factory = dict )
131
+ generic_error : Exception | None = None
132
+
133
+ def add_result (self , result : CallResult | CoroutineResult ):
134
+ assert isinstance (result , CoroutineResult )
135
+
136
+ try :
137
+ self .waiting .remove (result .coroutine_id )
138
+ except KeyError :
139
+ return
140
+
141
+ if result .error is None :
142
+ if self .first_result is None :
143
+ self .first_result = result
144
+ return
145
+
146
+ self .errors [result .coroutine_id ] = result .error
147
+
148
+ def add_error (self , error : Exception ):
149
+ if self .generic_error is None :
150
+ self .generic_error = error
151
+
152
+ def ready (self ) -> bool :
153
+ return (
154
+ self .generic_error is not None
155
+ or self .first_result is not None
156
+ or len (self .waiting ) == 0
157
+ )
158
+
159
+ def error (self ) -> Exception | None :
160
+ assert self .ready ()
161
+ if self .generic_error is not None :
162
+ return self .generic_error
163
+ if self .first_result is not None or len (self .errors ) == 0 :
164
+ return None
165
+ match len (self .errors ):
166
+ case 0 :
167
+ return None
168
+ case 1 :
169
+ return self .errors [self .order [0 ]]
170
+ case _:
171
+ return AnyException ([self .errors [id ] for id in self .order ])
172
+
173
+ def value (self ) -> Any :
174
+ assert self .ready ()
175
+ if len (self .order ) == 0 :
176
+ return None
177
+ assert self .first_result is not None
178
+ return self .first_result .value
179
+
180
+
181
+ @dataclass (slots = True )
182
+ class RaceFuture :
183
+ """A future result of a dispatch.coroutine.race() operation."""
184
+
185
+ waiting : set [CoroutineID ] = field (default_factory = set )
186
+ first_result : CoroutineResult | None = None
187
+ first_error : Exception | None = None
188
+
189
+ def add_result (self , result : CallResult | CoroutineResult ):
190
+ assert isinstance (result , CoroutineResult )
191
+
192
+ if result .error is not None :
193
+ if self .first_error is None :
194
+ self .first_error = result .error
195
+ else :
196
+ if self .first_result is None :
197
+ self .first_result = result
198
+
199
+ self .waiting .remove (result .coroutine_id )
200
+
201
+ def add_error (self , error : Exception ):
202
+ if self .first_error is None :
203
+ self .first_error = error
204
+
205
+ def ready (self ) -> bool :
206
+ return (
207
+ self .first_error is not None
208
+ or self .first_result is not None
209
+ or len (self .waiting ) == 0
210
+ )
211
+
212
+ def error (self ) -> Exception | None :
213
+ assert self .ready ()
214
+ return self .first_error
215
+
216
+ def value (self ) -> Any :
217
+ assert self .first_error is None
218
+ return self .first_result .value if self .first_result else None
219
+
220
+
119
221
@dataclass (slots = True )
120
222
class Coroutine :
121
223
"""An in-flight coroutine."""
@@ -386,30 +488,35 @@ def _run(self, input: Input) -> Output:
386
488
state .prev_callers .append (coroutine )
387
489
state .outstanding_calls += 1
388
490
389
- case Gather ():
390
- gather = coroutine_yield
391
-
392
- children = []
393
- for awaitable in gather .awaitables :
394
- g = awaitable .__await__ ()
395
- if not isinstance (g , DurableGenerator ):
396
- raise ValueError (
397
- "gather awaitable is not a @dispatch.function"
398
- )
399
- child_id = state .next_coroutine_id
400
- state .next_coroutine_id += 1
401
- child = Coroutine (
402
- id = child_id , parent_id = coroutine .id , coroutine = g
403
- )
404
- logger .debug ("enqueuing %s for %s" , child , coroutine )
405
- children .append (child )
491
+ case AllDirective ():
492
+ children = spawn_children (
493
+ state , coroutine , coroutine_yield .awaitables
494
+ )
406
495
407
- # Prepend children to get a depth-first traversal of coroutines.
408
- state .ready = children + state .ready
496
+ child_ids = [child .id for child in children ]
497
+ coroutine .result = AllFuture (
498
+ order = child_ids , waiting = set (child_ids )
499
+ )
500
+ state .suspended [coroutine .id ] = coroutine
501
+
502
+ case AnyDirective ():
503
+ children = spawn_children (
504
+ state , coroutine , coroutine_yield .awaitables
505
+ )
409
506
410
507
child_ids = [child .id for child in children ]
411
- coroutine .result = GatherFuture (
412
- order = child_ids , waiting = set (child_ids ), results = {}
508
+ coroutine .result = AnyFuture (
509
+ order = child_ids , waiting = set (child_ids )
510
+ )
511
+ state .suspended [coroutine .id ] = coroutine
512
+
513
+ case RaceDirective ():
514
+ children = spawn_children (
515
+ state , coroutine , coroutine_yield .awaitables
516
+ )
517
+
518
+ coroutine .result = RaceFuture (
519
+ waiting = {child .id for child in children }
413
520
)
414
521
state .suspended [coroutine .id ] = coroutine
415
522
@@ -446,6 +553,26 @@ def _run(self, input: Input) -> Output:
446
553
)
447
554
448
555
556
+ def spawn_children (
557
+ state : State , coroutine : Coroutine , awaitables : tuple [Awaitable [Any ], ...]
558
+ ) -> list [Coroutine ]:
559
+ children = []
560
+ for awaitable in awaitables :
561
+ g = awaitable .__await__ ()
562
+ if not isinstance (g , DurableGenerator ):
563
+ raise TypeError ("awaitable is not a @dispatch.function" )
564
+ child_id = state .next_coroutine_id
565
+ state .next_coroutine_id += 1
566
+ child = Coroutine (id = child_id , parent_id = coroutine .id , coroutine = g )
567
+ logger .debug ("enqueuing %s for %s" , child , coroutine )
568
+ children .append (child )
569
+
570
+ # Prepend children to get a depth-first traversal of coroutines.
571
+ state .ready = children + state .ready
572
+
573
+ return children
574
+
575
+
449
576
def correlation_id (coroutine_id : CoroutineID , call_id : CallID ) -> CorrelationID :
450
577
return coroutine_id << 32 | call_id
451
578
0 commit comments