diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ef2c13646..43789df11 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -84,7 +84,7 @@ jobs: - name: Tests MPICH if: ${{ matrix.mpi == 'mpich' }} run: | - for proc in 1 5 9 ; do + for proc in 1 2 3; do echo "Running MPI test ($MPI) with $proc processes" PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py @@ -93,8 +93,8 @@ jobs: - name: Tests OpenMPI if: ${{ matrix.mpi == 'openmpi' }} run: | - for proc in 1 2 ; do + for proc in 1 2 3; do echo "Running MPI test ($MPI) with $proc processes" - PYTHONPATH=. mpiexec -n $proc python3 ./test/test_mpi.py - PYTHONPATH=. mpiexec -n $proc python3 ./test/test_detector_blocks.py + PYTHONPATH=. mpiexec --oversubscribe -n $proc python3 ./test/test_mpi.py + PYTHONPATH=. mpiexec --oversubscribe -n $proc python3 ./test/test_detector_blocks.py done diff --git a/litebird_sim/io.py b/litebird_sim/io.py index 4f0812f12..4b3ae26ea 100644 --- a/litebird_sim/io.py +++ b/litebird_sim/io.py @@ -471,7 +471,7 @@ def write_list_of_observations( observations = [observations] except IndexError: # Empty list - # We do not want to return here, as we still need to participate to + # We do not want to return here, as we still need to participate in # the call to _compute_global_start_index below observations = [] # type: List[Observation] diff --git a/litebird_sim/mapmaking/binner.py b/litebird_sim/mapmaking/binner.py index eee93726d..eb4af9d9d 100644 --- a/litebird_sim/mapmaking/binner.py +++ b/litebird_sim/mapmaking/binner.py @@ -7,24 +7,23 @@ # functions and variable defined here use the same letters and symbols of that # paper. We refer to it in code comments and docstrings as "KurkiSuonio2009". +import logging from dataclasses import dataclass +from typing import Union, List, Any, Optional, Callable +import healpy as hp import numpy as np import numpy.typing as npt +from ducc0.healpix import Healpix_Base from numba import njit -import healpy as hp -from typing import Union, List, Any, Optional, Callable -from litebird_sim.observations import Observation -from litebird_sim.coordinates import CoordinateSystem -from litebird_sim.pointings import get_hwp_angle -from litebird_sim.hwp import HWP from litebird_sim import mpi -from ducc0.healpix import Healpix_Base +from litebird_sim.coordinates import CoordinateSystem from litebird_sim.healpix import nside_to_npix - -import logging - +from litebird_sim.hwp import HWP +from litebird_sim.mpi import MPI_COMM_GRID +from litebird_sim.observations import Observation +from litebird_sim.pointings import get_hwp_angle from .common import ( _compute_pixel_indices, _normalize_observations_and_pointings, @@ -263,7 +262,9 @@ def _build_nobs_matrix( for i in range(len(obs_list) - 1) ] ): - nobs_matrix = obs_list[0].comm.allreduce(nobs_matrix, mpi.MPI.SUM) + nobs_matrix = mpi.MPI_COMM_GRID.COMM_OBS_GRID.allreduce( + nobs_matrix, mpi.MPI.SUM + ) else: raise NotImplementedError( "All observations must be distributed over the same MPI groups" @@ -282,7 +283,7 @@ def make_binned_map( detector_split: str = "full", time_split: str = "full", pointings_dtype=np.float64, -) -> BinnerResult: +) -> Optional[BinnerResult]: """Bin Map-maker Map a list of observations @@ -319,9 +320,14 @@ def make_binned_map( Returns: An instance of the class :class:`.MapMakerResult`. If the observations are - distributed over MPI Processes, all of them get a copy of the same object. + distributed over MPI Processes, all of them get a copy of the same object, + unless the current MPI process does not hold any TOD sample: in the latter + case, ``None`` is returned. """ + if not MPI_COMM_GRID.is_this_process_in_grid(): + return None + if not components: components = ["tod"] diff --git a/litebird_sim/mapmaking/common.py b/litebird_sim/mapmaking/common.py index 441d1ca25..800fded3f 100644 --- a/litebird_sim/mapmaking/common.py +++ b/litebird_sim/mapmaking/common.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from typing import Union, List, Tuple, Callable + +import astropy.time import numpy as np import numpy.typing as npt -from numba import njit -import astropy.time - from ducc0.healpix import Healpix_Base +from numba import njit from litebird_sim.coordinates import CoordinateSystem, rotate_coordinates_e2g from litebird_sim.observations import Observation @@ -109,15 +109,14 @@ def get_map_making_weights( except AttributeError: weights = np.ones(observations.n_detectors) - if check and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: - if check: - # Check that there are no weird weights - assert np.all( - np.isfinite(weights) - ), f"Not all the detectors' weights are finite numbers: {weights}" - assert np.all( - weights > 0.0 - ), f"Not all the detectors' weights are positive: {weights}" + if check and MPI_COMM_GRID.is_this_process_in_grid(): + # Check that there are no weird weights + assert np.all( + np.isfinite(weights) + ), f"Not all the detectors' weights are finite numbers: {weights}" + assert np.all( + weights > 0.0 + ), f"Not all the detectors' weights are positive: {weights}" return weights diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py index c5c93c5d0..e8d05716d 100644 --- a/litebird_sim/mapmaking/destriper.py +++ b/litebird_sim/mapmaking/destriper.py @@ -1,32 +1,22 @@ # -*- encoding: utf-8 -*- +import gc import logging import time - -# The implementation of the destriping algorithm provided here is based on the paper -# «Destriping CMB temperature and polarization maps» by Kurki-Suonio et al. 2009, -# A&A 506, 1511–1539 (2009), https://dx.doi.org/10.1051/0004-6361/200912361 -# -# It is important to have that paper at hand while reading this code, as many -# functions and variable defined here use the same letters and symbols of that -# paper. We refer to it in code comments and docstrings as "KurkiSuonio2009". - from dataclasses import dataclass -import gc from pathlib import Path +from typing import Callable, Union, List, Optional, Tuple, Any, Dict +import healpy as hp import numpy as np import numpy.typing as npt from ducc0.healpix import Healpix_Base from numba import njit, prange -import healpy as hp -from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID -from typing import Callable, Union, List, Optional, Tuple, Any, Dict +from litebird_sim.coordinates import CoordinateSystem, coord_sys_to_healpix_string from litebird_sim.hwp import HWP +from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID from litebird_sim.observations import Observation from litebird_sim.pointings import get_hwp_angle -from litebird_sim.coordinates import CoordinateSystem, coord_sys_to_healpix_string - from .common import ( _compute_pixel_indices, _normalize_observations_and_pointings, @@ -39,6 +29,14 @@ _build_mask_time_split, ) +# The implementation of the destriping algorithm provided here is based on the paper +# «Destriping CMB temperature and polarization maps» by Kurki-Suonio et al. 2009, +# A&A 506, 1511–1539 (2009), https://dx.doi.org/10.1051/0004-6361/200912361 +# +# It is important to have that paper at hand while reading this code, as many +# functions and variable defined here use the same letters and symbols of that +# paper. We refer to it in code comments and docstrings as "KurkiSuonio2009". + if MPI_ENABLED: import mpi4py.MPI @@ -501,7 +499,7 @@ def _build_nobs_matrix( ) # Now we must accumulate the result of every MPI process - if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL: + if MPI_ENABLED and MPI_COMM_GRID.is_this_process_in_grid(): MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM ) @@ -750,7 +748,7 @@ def _compute_binned_map( baseline_lengths=cur_baseline_lengths, ) - if MPI_ENABLED: + if MPI_ENABLED and MPI_COMM_GRID.is_this_process_in_grid(): MPI_COMM_GRID.COMM_OBS_GRID.Allreduce( mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM ) @@ -995,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float: # we call “flatten” to make them 1D and produce *one* scalar out of # the dot product local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)]) - if MPI_ENABLED: + if MPI_ENABLED and MPI_COMM_GRID.is_this_process_in_grid(): return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM) else: return local_result @@ -1575,6 +1573,9 @@ def my_gui_callback( """ elapsed_time_s = time.monotonic() + if not MPI_COMM_GRID.is_this_process_in_grid(): + return None + if not components: components = ["tod"] diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py index 64c311813..1c0c5ec40 100644 --- a/litebird_sim/mpi.py +++ b/litebird_sim/mpi.py @@ -56,6 +56,15 @@ def _set_comm_obs_grid(self, comm_obs_grid): def _set_null_comm(self, comm_null): self._MPI_COMM_NULL = comm_null + def is_this_process_in_grid(self) -> bool: + """ + Return ``True`` if the current MPI process is in the MPI grid. + + If the function returns ``False``, then the current MPI process is not handling + any TOD sample. + """ + return self._MPI_COMM_OBS_GRID != self._MPI_COMM_NULL + #: Global variable equal either to `mpi4py.MPI.COMM_WORLD` or a object #: that defines the member variables `rank = 0` and `size = 1`. diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py index 9c7290167..267664d28 100644 --- a/litebird_sim/observations.py +++ b/litebird_sim/observations.py @@ -1,18 +1,17 @@ # -*- encoding: utf-8 -*- -from dataclasses import dataclass -from typing import Union, List, Any, Optional import numbers +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, Union, List, Any, Optional import astropy.time import numpy as np import numpy.typing as npt -from collections import defaultdict - from .coordinates import DEFAULT_TIME_SCALE -from .distribute import distribute_evenly, distribute_detector_blocks from .detectors import DetectorInfo +from .distribute import distribute_evenly, distribute_detector_blocks from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator @@ -39,6 +38,14 @@ class TodDescription: description: str +def _smart_list_to_numpy_array(lst: list) -> Union[list, npt.NDArray]: + # Anything in the form [None, None, …, None] is *not* converted to a NumPy array + if isinstance(lst, list) and all(x is None for x in lst): + return lst + + return np.array(lst) + + class Observation: """An observation made by one or multiple detectors over some time window @@ -241,40 +248,50 @@ def _get_local_start_time_start_and_n_samples(self): return self.start_time_global + start * delta, start, num - def _set_attributes_from_list_of_dict(self, list_of_dict, root): + def _set_attributes_from_list_of_dict( + self, list_of_dict: List[Dict[str, str]], root: int + ) -> None: + """ + Take a list of dictionaries describing each detector and propagate them + """ np.testing.assert_equal(len(list_of_dict), self.n_detectors_global) # Turn list of dict into dict of arrays if not self.comm or self.comm.rank == root: # Build a list of all the keys in the dictionaries contained within - # `list_of_dict` (which is a *list* of dictionaries) + # `list_of_dict` (which is a *list* of dictionaries). `keys` is a list of + # strings like `name`, `net_ukrts`, `fknee_mhz`, etc. keys = list(set().union(*list_of_dict) - set(dir(self))) # This will be the dictionary associating each key with the - # *array* of value for that dictionary - dict_of_array = {k: [] for k in keys} + # *array* of values for that dictionary + dict_of_array = {cur_key: [] for cur_key in keys} # This array associates either np.nan or None to each type; # the former indicates that the value is a NumPy array, while # None is used for everything else nan_or_none = {} - for k in keys: - for d in list_of_dict: - if k in d: + for cur_key in keys: + for cur_det_dict in list_of_dict: + if cur_key in cur_det_dict: try: - nan_or_none[k] = np.nan * d[k] + nan_or_none[cur_key] = np.nan * cur_det_dict[cur_key] except TypeError: - nan_or_none[k] = None + nan_or_none[cur_key] = None break # Finally, build `dict_of_array` - for d in list_of_dict: - for k in keys: - dict_of_array[k].append(d.get(k, nan_or_none[k])) + for cur_det_dict in list_of_dict: + for cur_key in keys: + dict_of_array[cur_key].append( + cur_det_dict.get(cur_key, nan_or_none[cur_key]) + ) - # Why should this code iterate over `keys`?!? - for k in keys: - dict_of_array = {k: np.array(dict_of_array[k]) for k in keys} + # So far, dict_of_array entries are plain lists. This converts them into NumPy arrays + dict_of_array = { + cur_key: _smart_list_to_numpy_array(dict_of_array[cur_key]) + for cur_key in keys + } else: keys = None dict_of_array = {} @@ -283,8 +300,8 @@ def _set_attributes_from_list_of_dict(self, list_of_dict, root): if self.comm and self.comm.size > 1: keys = self.comm.bcast(keys) - for k in keys: - self.setattr_det_global(k, dict_of_array.get(k), root) + for cur_key in keys: + self.setattr_det_global(cur_key, dict_of_array.get(cur_key), root) @property def n_samples_global(self): @@ -681,12 +698,11 @@ def setattr_det_global(self, name, info, root=0): setattr(self, name, info) return - if ( - MPI_COMM_GRID.COMM_OBS_GRID == MPI_COMM_GRID.COMM_NULL - ): # The process does not own any detector (and TOD) + if not MPI_COMM_GRID.is_this_process_in_grid(): + # The process does not own any detector (and TOD) null_det = DetectorInfo() attribute = getattr(null_det, name, None) - value = np.array([0]) if isinstance(attribute, numbers.Number) else [None] + value = np.array([0]) if isinstance(attribute, numbers.Number) else None setattr(self, name, value) return @@ -927,6 +943,10 @@ def get_pointings( return pointing_buffer, hwp_buffer + def no_mueller_hwp(self) -> bool: + "Return True if no detectors have defined a Mueller matrix for the HWP" + return (self.mueller_hwp is None) or all(m is None for m in self.mueller_hwp) + def _set_mpi_subcommunicators(self): """ This function splits the global MPI communicator into three kinds of diff --git a/litebird_sim/pointings_in_obs.py b/litebird_sim/pointings_in_obs.py index 2988413fb..ac3a0cf6d 100644 --- a/litebird_sim/pointings_in_obs.py +++ b/litebird_sim/pointings_in_obs.py @@ -41,16 +41,21 @@ def prepare_pointings( # If the hwp object is passed and is not initialised in the observations, it gets applied to all detectors if hwp is None: - assert all(m is None for m in cur_obs.mueller_hwp), ( + assert cur_obs.no_mueller_hwp(), ( "Some detectors have been initialized with a mueller_hwp," "but no HWP object has been passed to prepare_pointings." ) for cur_obs in obs_list: cur_obs.has_hwp = False else: - for idet in cur_obs.det_idx: - if cur_obs.mueller_hwp[idet] is None: - cur_obs.mueller_hwp[idet] = hwp.mueller + if cur_obs.det_idx is not None: + if cur_obs.no_mueller_hwp(): + cur_obs.mueller_hwp = [hwp.mueller for i in cur_obs.det_idx] + else: + for idet in list(cur_obs.det_idx): + if cur_obs.mueller_hwp[idet] is None: + cur_obs.mueller_hwp[idet] = hwp.mueller + for cur_obs in obs_list: cur_obs.has_hwp = True diff --git a/litebird_sim/scan_map.py b/litebird_sim/scan_map.py index d143eeee4..0f42a79e2 100644 --- a/litebird_sim/scan_map.py +++ b/litebird_sim/scan_map.py @@ -1,17 +1,18 @@ # -*- encoding: utf-8 -*- +import logging +from typing import Union, List, Dict, Optional + +import healpy as hp import numpy as np +from ducc0.healpix import Healpix_Base from numba import njit, prange -from ducc0.healpix import Healpix_Base -from typing import Union, List, Dict, Optional -from .observations import Observation -from .hwp import HWP, mueller_ideal_hwp -from .pointings import get_hwp_angle from .coordinates import rotate_coordinates_e2g, CoordinateSystem from .healpix import npix_to_nside -import logging -import healpy as hp +from .hwp import HWP, mueller_ideal_hwp +from .observations import Observation +from .pointings import get_hwp_angle @njit @@ -454,16 +455,18 @@ def scan_map_in_observations( 1 ] else: - assert all(m is None for m in cur_obs.mueller_hwp), ( + assert cur_obs.no_mueller_hwp(), ( "Detectors have been initialized with a mueller_hwp," - "but no HWP is either passed or initilized in the pointing" + "but no HWP is either passed or initialized in pointings" ) hwp_angle = None else: if isinstance(cur_ptg, np.ndarray): hwp_angle = get_hwp_angle(cur_obs, hwp) else: - logging.warning("HWP provided, but no precomputed pointings passed.") + logging.warning( + "To use an external HWP object, pass pre-calculated pointings" + ) scan_map( tod=getattr(cur_obs, component), diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py index 07493f65c..10492f4e7 100644 --- a/litebird_sim/simulations.py +++ b/litebird_sim/simulations.py @@ -1112,7 +1112,11 @@ def describe_mpi_distribution(self) -> Optional[MpiDistributionDescr]: numba_num_of_threads_all = [] # type: list[int] for obs in self.observations: - cur_det_names = list(obs.name) + try: + cur_det_names = list(obs.name) + except TypeError: + # This observation has no `.name` field, so it is empty + cur_det_names = [""] shapes = [ tuple(getattr(obs, cur_tod.name).shape) for cur_tod in self.tod_list @@ -1789,36 +1793,36 @@ def make_binned_map_splits( ) if write_to_disk: filenames = [] - for ds in detector_splits: - for ts in time_splits: - result = make_binned_map( - nside=nside, - observations=self.observations, - output_coordinate_system=output_coordinate_system, - components=components, - detector_split=ds, - time_split=ts, - pointings_dtype=pointings_dtype, - ) - file = f"binned_map_DET{ds}_TIME{ts}.fits" - names = ["I", "Q", "U"] - result = list(result.__dict__.items()) - mapp = result.pop(0)[1] - inv_cov = result.pop(0)[1] - coords = result.pop(0)[1].name - del result - inv_cov = inv_cov.T[np.tril_indices(3)] - inv_cov[[2, 3]] = inv_cov[[3, 2]] - inv_cov = list(inv_cov) - if include_inv_covariance: - names.extend(["II", "IQ", "IU", "QQ", "QU", "UU"]) - for _ in range(6): - mapp = np.append(mapp, inv_cov.pop(0)[None, :], axis=0) - filenames.append( - self.write_healpix_map( - file, mapp, column_names=names, coord=coords + if MPI_COMM_GRID.is_this_process_in_grid(): + for ds in detector_splits: + for ts in time_splits: + result = make_binned_map( + nside=nside, + observations=self.observations, + output_coordinate_system=output_coordinate_system, + components=components, + detector_split=ds, + time_split=ts, + ) + file = f"binned_map_DET{ds}_TIME{ts}.fits" + names = ["I", "Q", "U"] + result = list(result.__dict__.items()) + mapp = result.pop(0)[1] + inv_cov = result.pop(0)[1] + coords = result.pop(0)[1].name + del result + inv_cov = inv_cov.T[np.tril_indices(3)] + inv_cov[[2, 3]] = inv_cov[[3, 2]] + inv_cov = list(inv_cov) + if include_inv_covariance: + names.extend(["II", "IQ", "IU", "QQ", "QU", "UU"]) + for _ in range(6): + mapp = np.append(mapp, inv_cov.pop(0)[None, :], axis=0) + filenames.append( + self.write_healpix_map( + file, mapp, column_names=names, coord=coords + ) ) - ) return filenames else: binned_maps = {} @@ -1957,7 +1961,7 @@ def make_destriped_map_splits( baselines = result.baselines recycled_convergence = result.converged - if append_to_report: + if append_to_report and MPI_COMM_GRID.is_this_process_in_grid(): self._build_and_append_destriped_report( "report_destriper_splits.md", ts, ds, result ) @@ -1966,12 +1970,14 @@ def make_destriped_map_splits( base_file = ( f"DET{ds}_TIME{ts}_baselines_mpi{MPI_COMM_WORLD.rank:04d}.fits" ) - save_destriper_results( - result, - output_folder=self.base_path, - custom_dest_file=dest_file, - custom_base_file=base_file, - ) + + if MPI_COMM_GRID.is_this_process_in_grid(): + save_destriper_results( + result, + output_folder=self.base_path, + custom_dest_file=dest_file, + custom_base_file=base_file, + ) filenames.append((dest_file, base_file)) del baselines return filenames @@ -2002,7 +2008,7 @@ def make_destriped_map_splits( baselines = destriped_maps[f"{ds}_{ts}"].baselines recycled_convergence = destriped_maps[f"{ds}_{ts}"].converged - if append_to_report: + if append_to_report and MPI_COMM_GRID.is_this_process_in_grid(): self._build_and_append_destriped_report( "report_destriper_splits.md", ts, @@ -2057,7 +2063,7 @@ def make_destriped_map( pointings_dtype=pointings_dtype, ) - if append_to_report: + if append_to_report and MPI_COMM_GRID.is_this_process_in_grid(): self._build_and_append_destriped_report( "report_destriper.md", detector_split, time_split, results ) diff --git a/test/test_mpi.py b/test/test_mpi.py index d7e3d4d69..880036050 100644 --- a/test/test_mpi.py +++ b/test/test_mpi.py @@ -1,13 +1,20 @@ # -*- encoding: utf-8 -*- # NOTE: all the following tests should be valid also in a serial execution +from io import StringIO from tempfile import TemporaryDirectory -import numpy as np import astropy.time as astrotime +import numpy as np + import litebird_sim as lbs +# Do *not* call pytest on these tests! Some of them accept a path as an argument, +# and PyTest will pass a different path for each MPI process. This is *wrong*, +# as some of these tests assume that all the MPI processes get the same path! + + def test_observation_time(): comm_world = lbs.MPI_COMM_WORLD ref_time = astrotime.Time("2020-02-20", format="iso") @@ -79,7 +86,8 @@ def test_construction_from_detectors(): comm_world = lbs.MPI_COMM_WORLD if comm_world.rank == 0: - print(f"MPI configuration: {lbs.MPI_CONFIGURATION}") + tmpbuf = StringIO() + print(f"MPI configuration: {lbs.MPI_CONFIGURATION}", file=tmpbuf) det1 = dict( name="pol01", @@ -503,6 +511,162 @@ def test_simulation_random(): assert state3["state"]["state"] != state4["state"]["state"] +def _test_empty_observations(tmp_path, use_hwp): + """ + Check that when there are empty observations the code still runs + """ + comm = lbs.MPI_COMM_WORLD + + if comm.size != 3: + # This test is meant to be run with 3 processes + return + + ### Mission params + telescope = "LFT" + channel = "L4-140" + detector_list = [ + "000_005_027_UA_140_B", + "000_005_027_UA_140_T", + ] + mission_time_days = 10 + detector_sampling_freq = 1 + + ### Simulation params + imo_version = "vPTEP" + base_path = "mwe" + nside = 16 + + imo = lbs.Imo() + + # Simulation initialization + sim = lbs.Simulation( + base_path=base_path, + start_time=0.0, + duration_s=mission_time_days * 24 * 60 * 60.0, + random_seed=5455, + mpi_comm=comm, + imo=imo, + ) + + # Instrument definition + sim.set_instrument( + lbs.InstrumentInfo.from_imo( + imo, + f"/releases/{imo_version}/satellite/{telescope}/instrument_info", + ) + ) + + # Detector list + dets = [] + for n_det in detector_list: + det = lbs.DetectorInfo.from_imo( + url=f"/releases/{imo_version}/satellite/{telescope}/{channel}/{n_det}/detector_info", + imo=imo, + ) + det.sampling_rate_hz = detector_sampling_freq + dets.append(det) + + # Scanning strategy + sim.set_scanning_strategy( + scanning_strategy=lbs.SpinningScanningStrategy.from_imo( + imo=imo, + url=f"/releases/{imo_version}/satellite/scanning_parameters", + ), + ) + + if use_hwp: + # Ideal Half-wave plate + sim.set_hwp( + lbs.IdealHWP( + sim.instrument.hwp_rpm * 2 * np.pi / 60, + ), + ) + + # Create observations. Note that we require 3 observations per detector but + # 2 detector blocks + sim.create_observations( + detectors=dets, + num_of_obs_per_detector=3, + n_blocks_det=2, + n_blocks_time=1, + split_list_over_processes=False, + ) + distr = sim.describe_mpi_distribution() + + # Be sure that printing the MPI distribution does not make the code crash + if lbs.MPI_COMM_WORLD.rank == 0: + tmpdest = StringIO() + print(distr, file=tmpdest) + + # Of the three processes, the last one should not be included in the MPI grid + if comm.rank == 0: + assert lbs.MPI_COMM_GRID.is_this_process_in_grid() + elif comm.rank == 1: + assert lbs.MPI_COMM_GRID.is_this_process_in_grid() + elif comm.rank == 2: + assert not lbs.MPI_COMM_GRID.is_this_process_in_grid() + + # Compute pointings + sim.prepare_pointings() + + # Channel info + ch_info = [] + n_ch_info = lbs.FreqChannelInfo.from_imo( + imo, + f"/releases/{imo_version}/satellite/{telescope}/{channel}/channel_info", + ) + ch_info.append(n_ch_info) + + # CMB map + Mbsparams = lbs.MbsParameters( + make_cmb=True, + make_fg=False, + seed_cmb=1, + gaussian_smooth=True, + bandpass_int=False, + nside=nside, + units="uK_CMB", + maps_in_ecliptic=False, + output_string="mbs_cmb_lens", + ) + + mbs_obj = lbs.Mbs( + simulation=sim, + parameters=Mbsparams, + channel_list=ch_info, + ) + + if comm.rank == 0: + input_maps = mbs_obj.run_all() + else: + input_maps = None + + input_maps = lbs.MPI_COMM_WORLD.bcast(input_maps, 0) + + # Scanning the sky + lbs.scan_map_in_observations( + sim.observations, maps=input_maps[0][channel], input_map_in_galactic=True + ) + + binned_maps = sim.make_binned_map(nside) + + destriped_maps = sim.make_destriped_map(nside) + + if comm.rank != 2: + assert binned_maps is not None + assert destriped_maps is not None + else: + assert binned_maps is None + assert destriped_maps is None + + sim.flush() + + +def test_empty_observations(tmp_path): + for use_hwp in [True, False]: + _test_empty_observations(tmp_path, use_hwp) + + if __name__ == "__main__": test_observation_time() test_construction_from_detectors() @@ -512,27 +676,18 @@ def test_simulation_random(): test_observation_tod_set_blocks() test_simulation_random() - # It's critical that all MPI processes use the same output directory - if lbs.MPI_ENABLED: - if lbs.MPI_COMM_WORLD.rank == 0: + for cur_test_fn in [test_write_hdf5_mpi, test_empty_observations]: + # It's critical that all MPI processes use the same output directory + if lbs.MPI_ENABLED: + if lbs.MPI_COMM_WORLD.rank == 0: + tmp_dir = TemporaryDirectory() + tmp_path = tmp_dir.name + lbs.MPI_COMM_WORLD.bcast(tmp_path, root=0) + else: + tmp_dir = None + tmp_path = lbs.MPI_COMM_WORLD.bcast(None, root=0) + else: tmp_dir = TemporaryDirectory() tmp_path = tmp_dir.name - lbs.MPI_COMM_WORLD.bcast(tmp_path, root=0) - else: - tmp_dir = None - tmp_path = lbs.MPI_COMM_WORLD.bcast(None, root=0) - else: - tmp_dir = TemporaryDirectory() - tmp_path = tmp_dir.name - - try: - test_write_hdf5_mpi(tmp_path) - finally: - # Now we can remove the temporary directory, but first make - # sure that there are no other MPI processes still waiting to - # finish - if lbs.MPI_ENABLED: - lbs.MPI_COMM_WORLD.barrier() - if tmp_dir: - tmp_dir.cleanup() + cur_test_fn(tmp_path)