@@ -516,9 +516,60 @@ def calc_ml_prediction(self, input_state=None):
516516 # Index list of ensemble members
517517 list_member_index = list (ml_ne )
518518
519- # Run prediction in parallel using p_map
520- en_pred = p_map (self .sim .run_fwd_sim , list_state ,
521- list_member_index , num_cpus = no_tot_run , disable = self .disable_tqdm )
519+ if no_tot_run == 1 : # if not in parallel we use regular loop
520+ en_pred = [self .sim .run_fwd_sim (state , member_index ) for state , member_index in
521+ tqdm (zip (list_state , list_member_index ), total = len (list_state ))]
522+ elif self .sim .input_dict .get ('hpc' , False ): # Run prediction in parallel on hpc
523+ batch_size = no_tot_run # If more than 500 ensemble members, we limit the runs to batches of 500
524+ # Split the ensemble into batches of 500
525+ if batch_size >= 1000 :
526+ self .logger .info (f'Cannot run batch size of { no_tot_run } . Set to 1000' )
527+ batch_size = 1000
528+ en_pred = []
529+ batch_en = [np .arange (start , start + batch_size ) for start in
530+ np .arange (0 , self .ne - batch_size , batch_size )]
531+ if len (batch_en ): # if self.ne is less than batch_size
532+ batch_en .append (np .arange (batch_en [- 1 ][- 1 ]+ 1 , self .ne ))
533+ else :
534+ batch_en .append (np .arange (0 , self .ne ))
535+ for n_e in batch_en :
536+ _ = [self .sim .run_fwd_sim (state , member_index , nosim = True ) for state , member_index in
537+ zip ([list_state [curr_n ] for curr_n in n_e ], [list_member_index [curr_n ] for curr_n in n_e ])]
538+ # Run call_sim on the hpc
539+ if self .sim .options ['mpiarray' ]:
540+ job_id = self .sim .SLURM_ARRAY_HPC_run (
541+ n_e ,
542+ venv = os .path .join (os .path .dirname (sys .executable ), 'activate' ),
543+ filename = self .sim .file ,
544+ ** self .sim .options
545+ )
546+ else :
547+ job_id = self .sim .SLURM_HPC_run (
548+ n_e ,
549+ venv = os .path .join (os .path .dirname (sys .executable ),'activate' ),
550+ filename = self .sim .file ,
551+ ** self .sim .options
552+ )
553+
554+ # Wait for the simulations to finish
555+ if job_id :
556+ sim_status = self .sim .wait_for_jobs (job_id )
557+ else :
558+ print ("Job submission failed. Exiting." )
559+ sim_status = [False ]* len (n_e )
560+ # Extract the results. Need a local counter to check the results in the correct order
561+ for c_member , member_i in enumerate ([list_member_index [curr_n ] for curr_n in n_e ]):
562+ if sim_status [c_member ]:
563+ self .sim .extract_data (member_i )
564+ en_pred .append (deepcopy (self .sim .pred_data ))
565+ if self .sim .saveinfo is not None : # Try to save information
566+ store_ensemble_sim_information (self .sim .saveinfo , member_i )
567+ else :
568+ en_pred .append (False )
569+ self .sim .remove_folder (member_i )
570+ else : # Run prediction in parallel using p_map
571+ en_pred = p_map (self .sim .run_fwd_sim , list_state ,
572+ list_member_index , num_cpus = no_tot_run , disable = self .disable_tqdm )
522573
523574 # List successful runs and crashes
524575 list_crash = [indx for indx , el in enumerate (en_pred ) if el is False ]
0 commit comments