diff --git a/python/solid_dmft/dft_managers/mpi_helpers.py b/python/solid_dmft/dft_managers/mpi_helpers.py index a1376eb0..fa461747 100644 --- a/python/solid_dmft/dft_managers/mpi_helpers.py +++ b/python/solid_dmft/dft_managers/mpi_helpers.py @@ -59,6 +59,9 @@ def create_hostfile(number_cores, cluster_name): return None hostnames = mpi.world.gather(socket.gethostname(), root=0) + if cluster_name == 'slurm': + slurm_hostnames = [hostname.split('.')[0] for hostname in hostnames] # TODO: please find a better solution + hostnames = slurm_hostnames if mpi.is_master_node(): # create hostfile based on first number_cores ranks hosts = defaultdict(int) @@ -68,6 +71,7 @@ def create_hostfile(number_cores, cluster_name): mask_hostfile = {'openmpi': '{} slots={}', # OpenMPI format 'openmpi-intra': '{} slots={}', # OpenMPI format 'mpich': '{}:{}', # MPICH format + 'slurm': '{}', # SLURM format }[cluster_name] hostfile = 'dft.hostfile' @@ -148,6 +152,14 @@ def get_mpi_arguments(mpi_profile, mpi_exe, number_cores, dft_exe, hostfile): return [mpi_exe, '-launcher', 'ssh', '-hostfile', hostfile, '-np', str(number_cores), '-envlist', 'PATH'] + shlex.split(dft_exe) + if mpi_profile == 'slurm': + return [ + mpi_exe, '-n', str(number_cores), '--export=PATH', + '-N', os.getenv("SLURM_JOB_NUM_NODES"), '-A', os.getenv("SLURM_JOB_ACCOUNT"), + '-p', os.getenv("SLURM_JOB_PARTITION"), '-t', '05:00', #TODO: decide way to get time limit + '-w', f"./{hostfile}", + ] + shlex.split(dft_exe) + return None diff --git a/python/solid_dmft/io_tools/verify_input_params.py b/python/solid_dmft/io_tools/verify_input_params.py index 5f0b2ea3..3e646faa 100644 --- a/python/solid_dmft/io_tools/verify_input_params.py +++ b/python/solid_dmft/io_tools/verify_input_params.py @@ -72,7 +72,7 @@ def _verify_input_params_dft(params: FullConfig) -> None: if params['dft']['dft_code'] not in ('vasp', 'qe', None): raise ValueError(f'Invalid "dft.dft_code" = {params["dft"]["dft_code"]}.') - if params['dft']['mpi_env'] not in ('default', 'openmpi', 'openmpi-intra', 'mpich'): + if params['dft']['mpi_env'] not in ('default', 'openmpi', 'openmpi-intra', 'mpich', 'slurm'): raise ValueError(f'Invalid "dft.mpi_env" = {params["dft"]["mpi_env"]}.') if params['dft']['projector_type'] not in ('w90', 'plo'):