Skip to content

Commit f682255

Browse files
committed
HPC for ML prediction
1 parent f24138c commit f682255

File tree

1 file changed

+54
-3
lines changed

1 file changed

+54
-3
lines changed

ensemble/ensemble.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,60 @@ def calc_ml_prediction(self, input_state=None):
516516
# Index list of ensemble members
517517
list_member_index = list(ml_ne)
518518

519-
# Run prediction in parallel using p_map
520-
en_pred = p_map(self.sim.run_fwd_sim, list_state,
521-
list_member_index, num_cpus=no_tot_run, disable=self.disable_tqdm)
519+
if no_tot_run==1: # if not in parallel we use regular loop
520+
en_pred = [self.sim.run_fwd_sim(state, member_index) for state, member_index in
521+
tqdm(zip(list_state, list_member_index), total=len(list_state))]
522+
elif self.sim.input_dict.get('hpc', False): # Run prediction in parallel on hpc
523+
batch_size = no_tot_run # If more than 500 ensemble members, we limit the runs to batches of 500
524+
# Split the ensemble into batches of 500
525+
if batch_size >= 1000:
526+
self.logger.info(f'Cannot run batch size of {no_tot_run}. Set to 1000')
527+
batch_size = 1000
528+
en_pred = []
529+
batch_en = [np.arange(start, start + batch_size) for start in
530+
np.arange(0, self.ne - batch_size, batch_size)]
531+
if len(batch_en): # if self.ne is less than batch_size
532+
batch_en.append(np.arange(batch_en[-1][-1]+1, self.ne))
533+
else:
534+
batch_en.append(np.arange(0, self.ne))
535+
for n_e in batch_en:
536+
_ = [self.sim.run_fwd_sim(state, member_index, nosim=True) for state, member_index in
537+
zip([list_state[curr_n] for curr_n in n_e], [list_member_index[curr_n] for curr_n in n_e])]
538+
# Run call_sim on the hpc
539+
if self.sim.options['mpiarray']:
540+
job_id = self.sim.SLURM_ARRAY_HPC_run(
541+
n_e,
542+
venv=os.path.join(os.path.dirname(sys.executable), 'activate'),
543+
filename=self.sim.file,
544+
**self.sim.options
545+
)
546+
else:
547+
job_id=self.sim.SLURM_HPC_run(
548+
n_e,
549+
venv=os.path.join(os.path.dirname(sys.executable),'activate'),
550+
filename=self.sim.file,
551+
**self.sim.options
552+
)
553+
554+
# Wait for the simulations to finish
555+
if job_id:
556+
sim_status = self.sim.wait_for_jobs(job_id)
557+
else:
558+
print("Job submission failed. Exiting.")
559+
sim_status = [False]*len(n_e)
560+
# Extract the results. Need a local counter to check the results in the correct order
561+
for c_member, member_i in enumerate([list_member_index[curr_n] for curr_n in n_e]):
562+
if sim_status[c_member]:
563+
self.sim.extract_data(member_i)
564+
en_pred.append(deepcopy(self.sim.pred_data))
565+
if self.sim.saveinfo is not None: # Try to save information
566+
store_ensemble_sim_information(self.sim.saveinfo, member_i)
567+
else:
568+
en_pred.append(False)
569+
self.sim.remove_folder(member_i)
570+
else: # Run prediction in parallel using p_map
571+
en_pred = p_map(self.sim.run_fwd_sim, list_state,
572+
list_member_index, num_cpus=no_tot_run, disable=self.disable_tqdm)
522573

523574
# List successful runs and crashes
524575
list_crash = [indx for indx, el in enumerate(en_pred) if el is False]

0 commit comments

Comments
 (0)