1
1
import logging
2
2
import pickle
3
3
import sys
4
- from dataclasses import dataclass
5
- from typing import Any , Callable , Protocol , TypeAlias
4
+ from dataclasses import dataclass , field
5
+ from typing import Any , Awaitable , Callable , Protocol , TypeAlias
6
6
7
- from dispatch .coroutine import All
7
+ from dispatch .coroutine import AllDirective , AnyDirective , AnyException
8
8
from dispatch .error import IncompatibleStateError
9
9
from dispatch .experimental .durable .function import DurableCoroutine , DurableGenerator
10
10
from dispatch .proto import Call , Error , Input , Output
@@ -73,6 +73,7 @@ def error(self) -> Exception | None:
73
73
return self .first_error
74
74
75
75
def value (self ) -> Any :
76
+ assert self .first_error is None
76
77
assert self .result is not None
77
78
return self .result .value
78
79
@@ -81,9 +82,9 @@ def value(self) -> Any:
81
82
class AllFuture :
82
83
"""A future result of a dispatch.coroutine.all() operation."""
83
84
84
- order : list [CoroutineID ]
85
- waiting : set [CoroutineID ]
86
- results : dict [CoroutineID , CoroutineResult ]
85
+ order : list [CoroutineID ] = field ( default_factory = list )
86
+ waiting : set [CoroutineID ] = field ( default_factory = set )
87
+ results : dict [CoroutineID , CoroutineResult ] = field ( default_factory = dict )
87
88
first_error : Exception | None = None
88
89
89
90
def add_result (self , result : CallResult | CoroutineResult ):
@@ -94,13 +95,15 @@ def add_result(self, result: CallResult | CoroutineResult):
94
95
except KeyError :
95
96
return
96
97
97
- if result .error is not None and self .first_error is None :
98
- self .first_error = result .error
98
+ if result .error is not None :
99
+ if self .first_error is None :
100
+ self .first_error = result .error
101
+ return
99
102
100
103
self .results [result .coroutine_id ] = result
101
104
102
105
def add_error (self , error : Exception ):
103
- if self .first_error is not None :
106
+ if self .first_error is None :
104
107
self .first_error = error
105
108
106
109
def ready (self ) -> bool :
@@ -113,9 +116,68 @@ def error(self) -> Exception | None:
113
116
def value (self ) -> list [Any ]:
114
117
assert self .ready ()
115
118
assert len (self .waiting ) == 0
119
+ assert self .first_error is None
116
120
return [self .results [id ].value for id in self .order ]
117
121
118
122
123
+ @dataclass (slots = True )
124
+ class AnyFuture :
125
+ """A future result of a dispatch.coroutine.any() operation."""
126
+
127
+ order : list [CoroutineID ] = field (default_factory = list )
128
+ waiting : set [CoroutineID ] = field (default_factory = set )
129
+ first_result : CoroutineResult | None = None
130
+ errors : dict [CoroutineID , Exception ] = field (default_factory = dict )
131
+ generic_error : Exception | None = None
132
+
133
+ def add_result (self , result : CallResult | CoroutineResult ):
134
+ assert isinstance (result , CoroutineResult )
135
+
136
+ try :
137
+ self .waiting .remove (result .coroutine_id )
138
+ except KeyError :
139
+ return
140
+
141
+ if result .error is None :
142
+ if self .first_result is None :
143
+ self .first_result = result
144
+ return
145
+
146
+ self .errors [result .coroutine_id ] = result .error
147
+
148
+ def add_error (self , error : Exception ):
149
+ if self .generic_error is None :
150
+ self .generic_error = error
151
+
152
+ def ready (self ) -> bool :
153
+ return (
154
+ self .generic_error is not None
155
+ or self .first_result is not None
156
+ or len (self .waiting ) == 0
157
+ )
158
+
159
+ def error (self ) -> Exception | None :
160
+ assert self .ready ()
161
+ if self .generic_error is not None :
162
+ return self .generic_error
163
+ if self .first_result is not None or len (self .errors ) == 0 :
164
+ return None
165
+ match len (self .errors ):
166
+ case 0 :
167
+ return None
168
+ case 1 :
169
+ return self .errors [self .order [0 ]]
170
+ case _:
171
+ return AnyException ([self .errors [id ] for id in self .order ])
172
+
173
+ def value (self ) -> Any :
174
+ assert self .ready ()
175
+ if len (self .order ) == 0 :
176
+ return None
177
+ assert self .first_result is not None
178
+ return self .first_result .value
179
+
180
+
119
181
@dataclass (slots = True )
120
182
class Coroutine :
121
183
"""An in-flight coroutine."""
@@ -386,28 +448,25 @@ def _run(self, input: Input) -> Output:
386
448
state .prev_callers .append (coroutine )
387
449
state .outstanding_calls += 1
388
450
389
- case All ():
390
- children = []
391
- for awaitable in coroutine_yield .awaitables :
392
- g = awaitable .__await__ ()
393
- if not isinstance (g , DurableGenerator ):
394
- raise ValueError (
395
- "gather awaitable is not a @dispatch.function"
396
- )
397
- child_id = state .next_coroutine_id
398
- state .next_coroutine_id += 1
399
- child = Coroutine (
400
- id = child_id , parent_id = coroutine .id , coroutine = g
401
- )
402
- logger .debug ("enqueuing %s for %s" , child , coroutine )
403
- children .append (child )
404
-
405
- # Prepend children to get a depth-first traversal of coroutines.
406
- state .ready = children + state .ready
451
+ case AllDirective ():
452
+ children = spawn_children (
453
+ state , coroutine , coroutine_yield .awaitables
454
+ )
407
455
408
456
child_ids = [child .id for child in children ]
409
457
coroutine .result = AllFuture (
410
- order = child_ids , waiting = set (child_ids ), results = {}
458
+ order = child_ids , waiting = set (child_ids )
459
+ )
460
+ state .suspended [coroutine .id ] = coroutine
461
+
462
+ case AnyDirective ():
463
+ children = spawn_children (
464
+ state , coroutine , coroutine_yield .awaitables
465
+ )
466
+
467
+ child_ids = [child .id for child in children ]
468
+ coroutine .result = AnyFuture (
469
+ order = child_ids , waiting = set (child_ids )
411
470
)
412
471
state .suspended [coroutine .id ] = coroutine
413
472
@@ -444,6 +503,26 @@ def _run(self, input: Input) -> Output:
444
503
)
445
504
446
505
506
+ def spawn_children (
507
+ state : State , coroutine : Coroutine , awaitables : tuple [Awaitable [Any ], ...]
508
+ ) -> list [Coroutine ]:
509
+ children = []
510
+ for awaitable in awaitables :
511
+ g = awaitable .__await__ ()
512
+ if not isinstance (g , DurableGenerator ):
513
+ raise TypeError ("awaitable is not a @dispatch.function" )
514
+ child_id = state .next_coroutine_id
515
+ state .next_coroutine_id += 1
516
+ child = Coroutine (id = child_id , parent_id = coroutine .id , coroutine = g )
517
+ logger .debug ("enqueuing %s for %s" , child , coroutine )
518
+ children .append (child )
519
+
520
+ # Prepend children to get a depth-first traversal of coroutines.
521
+ state .ready = children + state .ready
522
+
523
+ return children
524
+
525
+
447
526
def correlation_id (coroutine_id : CoroutineID , call_id : CallID ) -> CorrelationID :
448
527
return coroutine_id << 32 | call_id
449
528
0 commit comments