Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# HEAD

- Save memory in pointing generation [#488](https://github.com/litebird/litebird_sim/pull/488)

-   **Breaking change**: Major reworking of the interfaces and handling of inputs across the framework [#479](https://github.com/litebird/litebird_sim/pull/479), in detail:

1. Rework the handling of spherical harmonics by integrating ducc0 as the primary engine for SHT operations, including interpolation.
Expand Down
16 changes: 15 additions & 1 deletion litebird_sim/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .input_sky import SkyGenerationParams
from .maps_and_harmonics import HealpixMap, SphericalHarmonics
from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator
from .pointings import PointingProvider
from .pointings import PointingProvider, DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB
from .scanning import RotQuaternion
from .units import Units

Expand Down Expand Up @@ -910,6 +910,7 @@ def prepare_pointings(
instrument: InstrumentInfo,
spin2ecliptic_quats: RotQuaternion,
hwp: HWP | None = None,
maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
) -> None:
"""Prepare quaternion-based pointing and HWP information for this observation.

Expand All @@ -936,6 +937,10 @@ def prepare_pointings(
Optional HWP model. If provided, it is stored and its Mueller matrix
applied to all detectors lacking one.

maximum_internal_buffer_mem_mb (float):
Maximum number of megabytes (MB) to allocate for internal buffers during
the computation of pointings. Set to -1 to remove any limit.

Raises:
AssertionError:
If `hwp` is not provided and one or more detectors do not have a
Expand All @@ -947,10 +952,19 @@ def prepare_pointings(
internal :class:`.PointingProvider`.
"""

assert (maximum_internal_buffer_mem_mb > 0) or (
maximum_internal_buffer_mem_mb == -1
), (
"Invalid value for maximum_internal_buffer_mem_mb ({val}), it must either be -1 or a positive number".format(
val=maximum_internal_buffer_mem_mb
)
)

bore2ecliptic_quats = spin2ecliptic_quats * instrument.bore2spin_quat
pointing_provider = PointingProvider(
bore2ecliptic_quats=bore2ecliptic_quats,
hwp=hwp,
maximum_internal_buffer_mem_mb=maximum_internal_buffer_mem_mb,
)

self.pointing_provider = pointing_provider
Expand Down
90 changes: 70 additions & 20 deletions litebird_sim/pointings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
RotQuaternion,
)

DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB = 256.0


class PointingProvider:
"""Provides detector pointing angles and HWP angles based on scanning geometry.
Expand Down Expand Up @@ -51,10 +53,42 @@ def __init__(
# Note that we require here *boresight*→Ecliptic instead of *spin*→Ecliptic
bore2ecliptic_quats: RotQuaternion,
hwp: HWP | None = None,
maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
):
self.bore2ecliptic_quats = bore2ecliptic_quats
self.maximum_internal_buffer_mem_mb = maximum_internal_buffer_mem_mb
self.hwp = hwp

def _optimal_block_lengths(self, total_nsamples: int) -> list[int]:
# Size of one quaternion, in bytes
quaternion_size_bytes = 4 * self.bore2ecliptic_quats.quats.itemsize

# Average number of quaternions in each block, to make sure that no more than
# a fixed number of MB is ever needed to store them
quaternions_per_block = int(
self.maximum_internal_buffer_mem_mb * 1024 * 1024 / quaternion_size_bytes
)

# How many blocks of quaternions will need to be processed
number_of_blocks = total_nsamples // quaternions_per_block

# If something was left out of the previous calculation, include an additional block
if total_nsamples % quaternions_per_block != 0:
number_of_blocks += 1

# Instead of making the first N−1 blocks of the same size and add any leftover to the
# last block, which might even have 1 sample, try to create the blocks so that they all
# have roughly the same number of elements
result = []
quaternions_left = total_nsamples
while quaternions_left > 0:
current_block_length = quaternions_left // (number_of_blocks - len(result))
result.append(current_block_length)
quaternions_left -= current_block_length

assert sum(result) == total_nsamples
return result

def has_hwp(self):
"""Return ``True`` if a HWP has been set.

Expand Down Expand Up @@ -122,27 +156,21 @@ def get_pointings(
one is a float and the other is an `astropy.time.Time`).
"""

full_quaternions = (self.bore2ecliptic_quats * detector_quat).slerp(
start_time=start_time,
sampling_rate_hz=sampling_rate_hz,
nsamples=nsamples,
)
if isinstance(start_time, astropy.time.Time):
assert isinstance(start_time_global, astropy.time.Time), (
"The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
)
start_time_s = (start_time - start_time_global).to("s").value
else:
assert isinstance(start_time_global, (int, float)), (
"The start_time is a float, so start_time_global must also be a float."
)
start_time_s = start_time - start_time_global

if self.hwp is not None:
if hwp_buffer is None:
hwp_buffer = np.empty(nsamples, dtype=pointings_dtype)

if isinstance(start_time, astropy.time.Time):
assert isinstance(start_time_global, astropy.time.Time), (
"The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
)
start_time_s = (start_time - start_time_global).to("s").value
else:
assert isinstance(start_time_global, (int, float)), (
"The start_time is a float, so start_time_global must also be a float."
)
start_time_s = start_time - start_time_global

self.hwp.get_hwp_angle(
output_buffer=hwp_buffer,
start_time_s=start_time_s,
Expand All @@ -154,10 +182,32 @@ def get_pointings(
if pointing_buffer is None:
pointing_buffer = np.empty(shape=(nsamples, 3), dtype=pointings_dtype)

all_compute_pointing_and_orientation(
result_matrix=pointing_buffer,
quat_matrix=full_quaternions,
)
block_lengths = self._optimal_block_lengths(total_nsamples=nsamples)

det_to_ecliptic_quats = self.bore2ecliptic_quats * detector_quat
cur_time = start_time
start_sample = 0
for cur_block_length in block_lengths:
cur_quaternions = det_to_ecliptic_quats.slerp(
start_time=cur_time,
sampling_rate_hz=sampling_rate_hz,
nsamples=cur_block_length,
)
all_compute_pointing_and_orientation(
result_matrix=pointing_buffer[
start_sample : (start_sample + cur_block_length), :
],
quat_matrix=cur_quaternions,
)

if isinstance(cur_time, astropy.time.Time):
cur_time += astropy.time.TimeDelta(
cur_block_length / sampling_rate_hz, format="sec"
)
else:
cur_time += cur_block_length / sampling_rate_hz

start_sample += cur_block_length

return pointing_buffer, hwp_buffer

Expand Down
51 changes: 51 additions & 0 deletions test/test_scanning.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,54 @@ def test_time_dependent_quaternions_operations():
expected[0, :] = qconst1.quats[0, :]
lbs.quat_right_multiply(expected[0, :], *qconst2.quats[0, :])
np.testing.assert_allclose(actual=result.quats, desired=expected)


def test_chunked_pointing_generation():
quat_array = lbs.RotQuaternion(
quats=np.array(
[
# This is not really a “rotating” quaternion: we repeat
# the same quaternion (90° rotation around x) thrice
# just for testing
[1.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 1.0],
[1.0, 0.0, 0.0, 1.0],
]
/ np.sqrt(2)
),
start_time=0.0,
sampling_rate_hz=0.25, # Four seconds per quaternion
)

# Make room for 5 quaternions at most
quaternion_size_in_bytes = 32
pp = lbs.PointingProvider(
bore2ecliptic_quats=quat_array,
maximum_internal_buffer_mem_mb=(quaternion_size_in_bytes * 5) / (1024 * 1024),
)

num_of_samples = 12
block_lengths = pp._optimal_block_lengths(total_nsamples=num_of_samples)
assert len(block_lengths) == 3
assert sum(block_lengths) == num_of_samples

pointing_buf, hwp_buf = pp.get_pointings(
detector_quat=lbs.RotQuaternion(quats=np.array([[0.0, 0.0, 0.0, 1.0]])),
start_time=0.0,
start_time_global=0.0,
sampling_rate_hz=1.0,
nsamples=num_of_samples,
)
assert pointing_buf.shape == (num_of_samples, 3)

# We expect the +z axis of the detector to be rotated by 90° around the x axis,
# so that it should point towards −y. This implies that ϑ = π/2 and φ = −π/2

# ϑ
np.testing.assert_allclose(pointing_buf[:, 0], np.pi / 2)

# φ
np.testing.assert_allclose(pointing_buf[:, 1], -np.pi / 2)

# ψ
np.testing.assert_allclose(pointing_buf[:, 2], 0, atol=1e-15)
Loading