11import sys
22from multiprocessing import cpu_count
33import time
4- from typing import Callable , List , Any , Tuple , Union
4+ from typing import Callable , List , Any , Optional , Tuple , Union
55from threading import Lock , Thread , Event
66from concurrent .futures import ThreadPoolExecutor
77import shutil
88
9- from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments
9+ from .git_command import GitException
10+
11+ from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments , UpdateArguments
1012
1113
1214class TaskTracker :
@@ -50,32 +52,38 @@ def done_task_counter(self) -> int:
5052class MonitoredFunction :
5153 def __init__ (
5254 self ,
53- fn : Callable ,
55+ fn : Callable [..., Union [ Exception ]] ,
5456 task_tracker : TaskTracker ,
5557 ):
56- self .fn = fn
58+ self ._fn = fn
5759 self ._task_tracker = task_tracker
5860
5961 def __call__ (self , * args : Union [RunnerArguments , AdditionalSwiftSourcesArguments ]):
6062 task_name = args [0 ].repo_name
6163 self ._task_tracker .mark_task_as_running (task_name )
6264 result = None
6365 try :
64- result = self .fn (* args )
66+ result = self ._fn (* args )
6567 except Exception as e :
6668 print (e )
6769 finally :
6870 self ._task_tracker .mark_task_as_done (task_name )
6971 return result
7072
7173
72- class ParallelRunner :
74+ class ParallelRunner () :
7375 def __init__ (
7476 self ,
75- fn : Callable ,
76- pool_args : List [ Union [RunnerArguments , AdditionalSwiftSourcesArguments ]],
77+ fn : Callable [..., None ] ,
78+ pool_args : Union [List [ UpdateArguments ], List [ AdditionalSwiftSourcesArguments ]],
7779 n_threads : int = 0 ,
7880 ):
81+ def run_safely (* args , ** kwargs ):
82+ try :
83+ fn (* args , ** kwargs )
84+ except GitException as e :
85+ return e
86+
7987 if n_threads == 0 :
8088 # Limit the number of threads as the performance regresses if the
8189 # number is too high.
@@ -84,7 +92,8 @@ def __init__(
8492 self ._monitor_polling_period = 0.1
8593 self ._terminal_width = shutil .get_terminal_size ().columns
8694 self ._pool_args = pool_args
87- self ._fn = fn
95+ self ._fn_name = fn .__name__
96+ self ._fn = run_safely
8897 self ._output_prefix = pool_args [0 ].output_prefix
8998 self ._nb_repos = len (pool_args )
9099 self ._stop_event = Event ()
@@ -93,8 +102,8 @@ def __init__(
93102 self ._task_tracker = TaskTracker ()
94103 self ._monitored_fn = MonitoredFunction (self ._fn , self ._task_tracker )
95104
96- def run (self ) -> List [Any ]:
97- print (f"Running ``{ self ._fn . __name__ } `` with up to { self ._n_threads } processes." )
105+ def run (self ) -> List [Union [ None , Exception ] ]:
106+ print (f"Running ``{ self ._fn_name } `` with up to { self ._n_threads } processes." )
98107 if self ._verbose :
99108 with ThreadPoolExecutor (max_workers = self ._n_threads ) as pool :
100109 results = list (pool .map (self ._fn , self ._pool_args , timeout = 1800 ))
@@ -129,13 +138,10 @@ def _monitor(self):
129138 sys .stdout .flush ()
130139
131140 @staticmethod
132- def check_results (results , op ) -> int :
133- """Function used to check the results of ParallelRunner.
134-
135- NOTE: This function was originally located in the shell module of
136- swift_build_support and should eventually be replaced with a better
137- parallel implementation.
138- """
141+ def check_results (
142+ results : Optional [List [Union [GitException , Exception , Any ]]], operation : str
143+ ) -> int :
144+ """Check the results of ParallelRunner and print the failures."""
139145
140146 fail_count = 0
141147 if results is None :
@@ -144,15 +150,10 @@ def check_results(results, op) -> int:
144150 if r is None :
145151 continue
146152 if fail_count == 0 :
147- print ("======%s FAILURES======" % op )
153+ print (f "======{ operation } FAILURES======" )
148154 fail_count += 1
149- if isinstance (r , str ):
155+ if isinstance (r , ( GitException , Exception ) ):
150156 print (r )
151157 continue
152- if not hasattr (r , "repo_path" ):
153- # TODO: create a proper Exception class with these attributes
154- continue
155- print ("%s failed (ret=%d): %s" % (r .repo_path , r .ret , r ))
156- if r .stderr :
157- print (r .stderr .decode ())
158+ print (r )
158159 return fail_count
0 commit comments