23
23
from serverlessworkflow .sdk .transition_data_condition import TransitionDataCondition
24
24
from serverlessworkflow .sdk .end_data_condition import EndDataCondition
25
25
26
- from transitions .extensions import HierarchicalMachine , GraphMachine
27
26
from transitions .extensions .nesting import NestedState
28
27
import warnings
29
28
@@ -260,7 +259,67 @@ def sleep_state_details(self):
260
259
261
260
def event_state_details (self ):
262
261
if isinstance (self .current_state , EventState ):
263
- self .state_to_machine_state (["event_state" , "state" ])
262
+ state = self .state_to_machine_state (["event_state" , "state" ])
263
+ if self .get_actions :
264
+ if on_events := self .current_state .onEvents :
265
+ state .initial = [] if len (on_events ) > 1 else on_events [0 ]
266
+ for i , oe in enumerate (on_events ):
267
+ state .add_substate (
268
+ oe_state := self .state_machine .state_cls (
269
+ oe_name := f"onEvent { i } "
270
+ )
271
+ )
272
+
273
+ # define initial state
274
+ if i == 0 and len (on_events ) > 1 :
275
+ state .initial = [oe_state .name ]
276
+ elif i == 0 and len (on_events ) == 1 :
277
+ state .initial = oe_state .name
278
+ else :
279
+ state .initial .append (oe_state .name )
280
+
281
+ event_names = []
282
+ for ie , event in enumerate (oe .eventRefs ):
283
+ oe_state .add_substate (
284
+ ns := self .state_machine .state_cls (event )
285
+ )
286
+ ns .tags = ["event" ]
287
+ self .get_action_event (state = ns , e_name = event )
288
+ event_names .append (event )
289
+
290
+ # define initial state
291
+ if ie == 0 and len (oe .eventRefs ) > 1 :
292
+ oe_state .initial = [event ]
293
+ elif ie == 0 and len (oe .eventRefs ) == 1 :
294
+ oe_state .initial = event
295
+ else :
296
+ oe_state .initial .append (event )
297
+
298
+ if self .current_state .exclusive :
299
+ oe_state .add_substate (
300
+ ns := self .state_machine .state_cls (
301
+ action_name := f"action { ie } "
302
+ )
303
+ )
304
+ self .state_machine .add_transition (
305
+ trigger = "" ,
306
+ source = f"{ self .current_state .name } .{ oe_name } .{ event } " ,
307
+ dest = f"{ self .current_state .name } .{ oe_name } .{ action_name } " ,
308
+ )
309
+ self .generate_actions_info (
310
+ machine_state = ns ,
311
+ state_name = f"{ self .current_state .name } .{ oe_name } .{ action_name } " ,
312
+ actions = oe .actions ,
313
+ action_mode = oe .actionMode ,
314
+ )
315
+ if not self .current_state .exclusive and oe .actions :
316
+ self .generate_actions_info (
317
+ machine_state = oe_state ,
318
+ state_name = f"{ self .current_state .name } .{ oe_name } " ,
319
+ actions = oe .actions ,
320
+ action_mode = oe .actionMode ,
321
+ initial_states = event_names ,
322
+ )
264
323
265
324
def foreach_state_details (self ):
266
325
if isinstance (self .current_state , ForEachState ):
@@ -353,6 +412,7 @@ def generate_actions_info(
353
412
state_name : str ,
354
413
actions : List [Dict [str , Action ]],
355
414
action_mode : str = "sequential" ,
415
+ initial_states : List [str ] = [],
356
416
):
357
417
if self .get_actions :
358
418
parallel_states = []
@@ -387,7 +447,11 @@ def generate_actions_info(
387
447
ns := self .state_machine .state_cls (name )
388
448
)
389
449
ns .tags = ["event" ]
390
- self .get_action_event (state = ns , e_name = name )
450
+ self .get_action_event (
451
+ state = ns ,
452
+ e_name = action .eventRef .triggerEventRef ,
453
+ er_name = action .eventRef .resultEventRef ,
454
+ )
391
455
if name :
392
456
if action_mode == "sequential" :
393
457
if i < len (actions ) - 1 :
@@ -439,19 +503,36 @@ def generate_actions_info(
439
503
)
440
504
ns .tags = ["event" ]
441
505
self .get_action_event (
442
- state = ns , e_name = next_name
506
+ state = ns ,
507
+ e_name = action .eventRef .triggerEventRef ,
508
+ er_name = action .eventRef .resultEventRef ,
443
509
)
444
510
self .state_machine .add_transition (
445
511
trigger = "" ,
446
512
source = f"{ state_name } .{ name } " ,
447
513
dest = f"{ state_name } .{ next_name } " ,
448
514
)
449
- if i == 0 :
515
+ if i == 0 and not initial_states :
450
516
machine_state .initial = name
517
+ elif i == 0 and initial_states :
518
+ for init_s in initial_states :
519
+ self .state_machine .add_transition (
520
+ trigger = "" ,
521
+ source = f"{ state_name } .{ init_s } " ,
522
+ dest = f"{ state_name } .{ name } " ,
523
+ )
451
524
elif action_mode == "parallel" :
452
525
parallel_states .append (name )
453
- if action_mode == "parallel" :
526
+ if action_mode == "parallel" and not initial_states :
454
527
machine_state .initial = parallel_states
528
+ elif action_mode == "parallel" and initial_states :
529
+ for init_s in initial_states :
530
+ for ps in parallel_states :
531
+ self .state_machine .add_transition (
532
+ trigger = "" ,
533
+ source = f"{ state_name } .{ init_s } " ,
534
+ dest = f"{ state_name } .{ ps } " ,
535
+ )
455
536
456
537
def get_action_function (self , state : NestedState , f_name : str ):
457
538
if self .workflow .functions :
@@ -461,13 +542,14 @@ def get_action_function(self, state: NestedState, f_name: str):
461
542
state .metadata = {"function" : current_function }
462
543
break
463
544
464
- def get_action_event (self , state : NestedState , e_name : str ):
545
+ def get_action_event (self , state : NestedState , e_name : str , er_name : str = "" ):
465
546
if self .workflow .events :
466
547
for event in self .workflow .events :
467
548
current_event = event .serialize ().__dict__
468
549
if current_event ["name" ] == e_name :
469
550
state .metadata = {"event" : current_event }
470
- break
551
+ if current_event ["name" ] == er_name :
552
+ state .metadata = {"result_event" : current_event }
471
553
472
554
def subflow_state_name (self , action : Action , subflow : Workflow ):
473
555
return (
0 commit comments