diff --git a/CHANGELOG.md b/CHANGELOG.md
index a2df2ba1..d4dc40a1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,7 @@
 # HEAD
 
+- Save memory in pointing generation [#488](https://github.com/litebird/litebird_sim/pull/488)
+
 -   **Breaking change**: Major reworking of the interfaces and handling of inputs across the framework [#479](https://github.com/litebird/litebird_sim/pull/479), in detail:
 
     1. Rework the handling of spherical harmonics by integrating ducc0 as the primary engine for SHT operations, including interpolation.
diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py
index e026c3ba..aa0b4a2d 100644
--- a/litebird_sim/observations.py
+++ b/litebird_sim/observations.py
@@ -14,7 +14,7 @@
 from .input_sky import SkyGenerationParams
 from .maps_and_harmonics import HealpixMap, SphericalHarmonics
 from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator
-from .pointings import PointingProvider
+from .pointings import PointingProvider, DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB
 from .scanning import RotQuaternion
 from .units import Units
 
@@ -910,6 +910,7 @@ def prepare_pointings(
         instrument: InstrumentInfo,
         spin2ecliptic_quats: RotQuaternion,
         hwp: HWP | None = None,
+        maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
     ) -> None:
         """Prepare quaternion-based pointing and HWP information for this observation.
 
@@ -936,6 +937,10 @@ def prepare_pointings(
                 Optional HWP model. If provided, it is stored and its Mueller matrix
                 applied to all detectors lacking one.
 
+            maximum_internal_buffer_mem_mb (float):
+                Maximum number of megabytes (MB) to allocate for internal buffers during
+                the computation of pointings. Set to -1 to remove any limit.
+
         Raises:
             AssertionError:
                 If `hwp` is not provided and one or more detectors do not have a
@@ -947,10 +952,19 @@ def prepare_pointings(
             internal :class:`.PointingProvider`.
         """
 
+        assert (maximum_internal_buffer_mem_mb > 0) or (
+            maximum_internal_buffer_mem_mb == -1
+        ), (
+            "Invalid value for maximum_internal_buffer_mem_mb ({val}), it must either be -1 or a positive number".format(
+                val=maximum_internal_buffer_mem_mb
+            )
+        )
+
         bore2ecliptic_quats = spin2ecliptic_quats * instrument.bore2spin_quat
         pointing_provider = PointingProvider(
             bore2ecliptic_quats=bore2ecliptic_quats,
             hwp=hwp,
+            maximum_internal_buffer_mem_mb=maximum_internal_buffer_mem_mb,
         )
         self.pointing_provider = pointing_provider
 
diff --git a/litebird_sim/pointings.py b/litebird_sim/pointings.py
index 0fee9f39..c074903e 100644
--- a/litebird_sim/pointings.py
+++ b/litebird_sim/pointings.py
@@ -8,6 +8,8 @@
     RotQuaternion,
 )
 
+DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB = 256.0
+
 
 class PointingProvider:
     """Provides detector pointing angles and HWP angles based on scanning geometry.
@@ -51,10 +53,54 @@ def __init__(
         # Note that we require here *boresight*→Ecliptic instead of *spin*→Ecliptic
         bore2ecliptic_quats: RotQuaternion,
         hwp: HWP | None = None,
+        maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
     ):
         self.bore2ecliptic_quats = bore2ecliptic_quats
+        self.maximum_internal_buffer_mem_mb = maximum_internal_buffer_mem_mb
         self.hwp = hwp
 
+    def _optimal_block_lengths(self, total_nsamples: int) -> list[int]:
+        """Split `total_nsamples` into blocks that fit the internal buffer budget.
+
+        The result is a list of block lengths of roughly the same size, whose sum
+        is `total_nsamples`. If `maximum_internal_buffer_mem_mb` is negative (-1
+        means “no limit”), all the samples are put in one block.
+        """
+        if self.maximum_internal_buffer_mem_mb < 0:
+            # No limit on the buffer size: process all the samples at once
+            number_of_blocks = 1
+        else:
+            # Size of one quaternion, in bytes
+            quaternion_size_bytes = 4 * self.bore2ecliptic_quats.quats.itemsize
+
+            # Average number of quaternions in each block, to make sure that no more
+            # than a fixed number of MB is ever needed to store them. Each block must
+            # hold at least one quaternion, otherwise a budget smaller than one
+            # quaternion would cause a division by zero below
+            bytes_available = self.maximum_internal_buffer_mem_mb * 1024 * 1024
+            quaternions_per_block = max(1, int(bytes_available / quaternion_size_bytes))
+
+            # How many blocks of quaternions will need to be processed
+            number_of_blocks = total_nsamples // quaternions_per_block
+
+            # If something was left out of the previous calculation, include an
+            # additional block
+            if total_nsamples % quaternions_per_block != 0:
+                number_of_blocks += 1
+
+        # Instead of making the first N−1 blocks of the same size and adding any
+        # leftover to the last block, which might even have 1 sample, try to create
+        # the blocks so that they all have roughly the same number of elements
+        result = []
+        quaternions_left = total_nsamples
+        while quaternions_left > 0:
+            current_block_length = quaternions_left // (number_of_blocks - len(result))
+            result.append(current_block_length)
+            quaternions_left -= current_block_length
+
+        assert sum(result) == total_nsamples
+        return result
+
     def has_hwp(self):
         """Return ``True`` if a HWP has been set.
 
@@ -122,27 +168,21 @@ def get_pointings(
                 one is a float and the other is an `astropy.time.Time`).
         """
 
-        full_quaternions = (self.bore2ecliptic_quats * detector_quat).slerp(
-            start_time=start_time,
-            sampling_rate_hz=sampling_rate_hz,
-            nsamples=nsamples,
-        )
+        if isinstance(start_time, astropy.time.Time):
+            assert isinstance(start_time_global, astropy.time.Time), (
+                "The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
+            )
+            start_time_s = (start_time - start_time_global).to("s").value
+        else:
+            assert isinstance(start_time_global, (int, float)), (
+                "The start_time is a float, so start_time_global must also be a float."
+            )
+            start_time_s = start_time - start_time_global
 
         if self.hwp is not None:
             if hwp_buffer is None:
                 hwp_buffer = np.empty(nsamples, dtype=pointings_dtype)
 
-            if isinstance(start_time, astropy.time.Time):
-                assert isinstance(start_time_global, astropy.time.Time), (
-                    "The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
-                )
-                start_time_s = (start_time - start_time_global).to("s").value
-            else:
-                assert isinstance(start_time_global, (int, float)), (
-                    "The start_time is a float, so start_time_global must also be a float."
-                )
-                start_time_s = start_time - start_time_global
-
             self.hwp.get_hwp_angle(
                 output_buffer=hwp_buffer,
                 start_time_s=start_time_s,
@@ -154,10 +194,32 @@ def get_pointings(
         if pointing_buffer is None:
             pointing_buffer = np.empty(shape=(nsamples, 3), dtype=pointings_dtype)
 
-        all_compute_pointing_and_orientation(
-            result_matrix=pointing_buffer,
-            quat_matrix=full_quaternions,
-        )
+        block_lengths = self._optimal_block_lengths(total_nsamples=nsamples)
+
+        det_to_ecliptic_quats = self.bore2ecliptic_quats * detector_quat
+        cur_time = start_time
+        start_sample = 0
+        for cur_block_length in block_lengths:
+            cur_quaternions = det_to_ecliptic_quats.slerp(
+                start_time=cur_time,
+                sampling_rate_hz=sampling_rate_hz,
+                nsamples=cur_block_length,
+            )
+            all_compute_pointing_and_orientation(
+                result_matrix=pointing_buffer[
+                    start_sample : (start_sample + cur_block_length), :
+                ],
+                quat_matrix=cur_quaternions,
+            )
+
+            if isinstance(cur_time, astropy.time.Time):
+                cur_time += astropy.time.TimeDelta(
+                    cur_block_length / sampling_rate_hz, format="sec"
+                )
+            else:
+                cur_time += cur_block_length / sampling_rate_hz
+
+            start_sample += cur_block_length
 
         return pointing_buffer, hwp_buffer
 
diff --git a/test/test_scanning.py b/test/test_scanning.py
index 515eb35c..87cca16a 100644
--- a/test/test_scanning.py
+++ b/test/test_scanning.py
@@ -538,3 +538,70 @@ def test_time_dependent_quaternions_operations():
     expected[0, :] = qconst1.quats[0, :]
     lbs.quat_right_multiply(expected[0, :], *qconst2.quats[0, :])
     np.testing.assert_allclose(actual=result.quats, desired=expected)
+
+
+def test_chunked_pointing_generation():
+    quat_array = lbs.RotQuaternion(
+        quats=np.array(
+            [
+                # This is not really a “rotating” quaternion: we repeat
+                # the same quaternion (90° rotation around x) thrice
+                # just for testing
+                [1.0, 0.0, 0.0, 1.0],
+                [1.0, 0.0, 0.0, 1.0],
+                [1.0, 0.0, 0.0, 1.0],
+            ]
+            / np.sqrt(2)
+        ),
+        start_time=0.0,
+        sampling_rate_hz=0.25,  # Four seconds per quaternion
+    )
+
+    # Make room for 5 quaternions at most
+    quaternion_size_in_bytes = 32
+    pp = lbs.PointingProvider(
+        bore2ecliptic_quats=quat_array,
+        maximum_internal_buffer_mem_mb=(quaternion_size_in_bytes * 5) / (1024 * 1024),
+    )
+
+    num_of_samples = 12
+    block_lengths = pp._optimal_block_lengths(total_nsamples=num_of_samples)
+    assert len(block_lengths) == 3
+    assert sum(block_lengths) == num_of_samples
+
+    pointing_buf, hwp_buf = pp.get_pointings(
+        detector_quat=lbs.RotQuaternion(quats=np.array([[0.0, 0.0, 0.0, 1.0]])),
+        start_time=0.0,
+        start_time_global=0.0,
+        sampling_rate_hz=1.0,
+        nsamples=num_of_samples,
+    )
+    assert pointing_buf.shape == (num_of_samples, 3)
+
+    # We expect the +z axis of the detector to be rotated by 90° around the x axis,
+    # so that it should point towards −y. This implies that ϑ = π/2 and φ = −π/2
+
+    # ϑ
+    np.testing.assert_allclose(pointing_buf[:, 0], np.pi / 2)
+
+    # φ
+    np.testing.assert_allclose(pointing_buf[:, 1], -np.pi / 2)
+
+    # ψ
+    np.testing.assert_allclose(pointing_buf[:, 2], 0, atol=1e-15)
+
+    # A budget of -1 removes any limit: all the samples must end up in one block
+    unlimited_pp = lbs.PointingProvider(
+        bore2ecliptic_quats=quat_array,
+        maximum_internal_buffer_mem_mb=-1,
+    )
+    assert unlimited_pp._optimal_block_lengths(total_nsamples=num_of_samples) == [
+        num_of_samples
+    ]
+
+    # A budget smaller than one quaternion must lead to one sample per block
+    tiny_pp = lbs.PointingProvider(
+        bore2ecliptic_quats=quat_array,
+        maximum_internal_buffer_mem_mb=1 / (1024 * 1024),
+    )
+    assert tiny_pp._optimal_block_lengths(total_nsamples=3) == [1, 1, 1]