Avoid NaNs in Spherical and Spins Transforms #205
|
Walkthrough
The changes introduce new numerically stable utility functions for converting Cartesian coordinates to spherical angles, addressing edge cases where gradients could become undefined. These functions are added to a utility module and then integrated into existing transformation and conversion routines, replacing direct calls to standard trigonometric functions. The updates also include minor formatting improvements for code clarity and consistency, particularly around exponentiation operators and angle normalization. No public API signatures are changed, and the core logic remains intact.
Sequence Diagram(s)
sequenceDiagram
participant User
participant Transform
participant Utils
User->>Transform: Request inverse transform (e.g., SphereSpinToCartesianSpinTransform)
Transform->>Utils: Call carte_to_spherical_angles(x, y, z)
Utils-->>Transform: Return (theta, phi) with stable gradients
Transform-->>User: Return spherical angles (theta, phi)
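The diagram says the utility returns angles "with stable gradients" but does not show how that can be achieved. The sketch below illustrates one common JAX pattern for it, often called the "double where" trick, applied to `arctan2`. It is an illustration under stated assumptions, not the PR's implementation: the name `safe_arctan2_sketch` and the value of `EPS` are hypothetical, and the real `safe_arctan2` in `src/jimgw/utils.py` may differ.

```python
import jax
import jax.numpy as jnp

EPS = 1e-12  # hypothetical threshold for "at the origin"


def safe_arctan2_sketch(y, x, default_value=0.0):
    """Hypothetical sketch: arctan2 whose gradient stays finite at (0, 0)."""
    at_origin = (jnp.abs(x) < EPS) & (jnp.abs(y) < EPS)
    # Inner where: at the origin, hand arctan2 a harmless point (y=0, x=1)
    # so that its local derivative is finite.
    x_safe = jnp.where(at_origin, 1.0, x)
    y_safe = jnp.where(at_origin, 0.0, y)
    # Outer where: at the origin, return the chosen default value instead.
    return jnp.where(at_origin, default_value, jnp.arctan2(y_safe, x_safe))


# Gradients with respect to y and x at the origin.
grad_y = jax.grad(safe_arctan2_sketch, argnums=0)(0.0, 0.0)
grad_x = jax.grad(safe_arctan2_sketch, argnums=1)(0.0, 0.0)
```

A single `jnp.where` around the plain `jnp.arctan2` call is usually not enough, because the unselected branch is still differentiated and a non-finite local derivative multiplied by a zero cotangent yields NaN. Substituting safe inputs first avoids evaluating the derivative at the singular point at all.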
|
Caution
Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.
Actionable comments posted: 3
🔭 Outside diff range comments (1)
src/jimgw/single_event/utils.py (1)
280-320: 🛠️ Refactor suggestion
Clamping cos_beta to avoid NaNs.
`sin_beta = jnp.sqrt(1 - cos_beta**2)` can produce NaNs if `cos_beta**2` slightly exceeds 1 due to floating-point rounding. Consider clamping `cos_beta` into `[-1, 1]` right before.
- sin_beta = jnp.sqrt(1 - cos_beta ** 2)
+ cos_beta_clamped = jnp.clip(cos_beta, -1.0, 1.0)
+ sin_beta = jnp.sqrt(1.0 - cos_beta_clamped**2)
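To make the failure mode concrete, here is a small sketch of the overshoot case in float32. The helper name `sin_from_cos` and the overshoot value are made up for illustration; only `jnp.clip` and `jnp.sqrt` come from the suggestion itself.

```python
import jax.numpy as jnp


def sin_from_cos(cos_beta):
    """Hypothetical helper: sin(beta) from cos(beta), tolerant of overshoot."""
    cos_beta_clamped = jnp.clip(cos_beta, -1.0, 1.0)
    return jnp.sqrt(1.0 - cos_beta_clamped**2)


# A cosine that rounding has pushed just past 1 (float32).
overshoot = jnp.float32(1.0) + jnp.float32(1e-6)

unclamped = jnp.sqrt(1.0 - overshoot**2)  # square root of a negative number
clamped = sin_from_cos(overshoot)         # argument is clipped to 1 first
```

Clamping repairs the forward value only. The derivative of the square root is unbounded at zero, so gradients taken exactly at `cos_beta = ±1` may still be non-finite and could need separate handling.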
🧹 Nitpick comments (25)
test/integration/test_GW150914_D.py (1)
114-115: Commented out results processing
The lines for retrieving and printing samples are commented out, making this more of a smoke test for initialization rather than validating results. Consider adding a basic assertion to ensure the test is meaningful.
Add a basic assertion to check that sampling completed successfully rather than just commenting out the result processing:
-# jim.get_samples()
-# jim.print_summary()
+# Minimal validation that sampling worked
+samples = jim.get_samples()
+assert samples is not None, "Sampling did not produce any results"
test/unit/source_files/sky_locations/bilby_sky_locations.py (1)
36-39: Verify output directory exists
The script assumes the output directory exists, but doesn't check or create it if missing.
Consider adding a directory check/creation before saving:
+import os
+os.makedirs(f"{outdir}/sky_locations", exist_ok=True)
+
 jnp.savez(
     f"{outdir}/sky_locations/test_{ifo_pair[0]}_{ifo_pair[1]}_{time}.npz",
     **input_dict,
 )
example/GW150914_IMRPhenomD.py (1)
25-25: Remove unused import.
`optax` is imported but not used within this file.
- import optax
🧰 Tools
🪛 Ruff (0.8.2)
25-25: `optax` imported but unused
Remove unused import: `optax`
(F401)
src/jimgw/constants.py (1)
1-1: Acknowledge TODO comment on type stubs.
The comment indicates missing type stubs for `astropy`. Consider addressing them in a future improvement to enhance type checking.
test/integration/test_periodic_uniform.py (2)
6-6: Remove unused import.
The `BoundToUnbound` import is only used in commented-out code.
-from jimgw.transforms import PeriodicTransform, BoundToUnbound
+from jimgw.transforms import PeriodicTransform
🧰 Tools
🪛 Ruff (0.8.2)
6-6: `jimgw.transforms.BoundToUnbound` imported but unused
Remove unused import: `jimgw.transforms.BoundToUnbound`
(F401)
76-99: Useful visualization code.
The commented visualization code provides a good way to analyze the sampling results. Consider uncommenting it or moving it to a separate function that can be optionally called.
Consider moving the visualization code to a separate function that can be called optionally:
def visualize_results(samples, prior):
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(samples["test"])
    plt.ylim(0.0, 2.0 * jnp.pi)
    plt.xlim(0, 300)
    plt.title("Walker history")

    plt.subplot(1, 2, 2)
    plt.hist(samples["test"], label="Samples", density=True, bins=50)
    x = jnp.linspace(0.0, 2.0 * jnp.pi, 1000)
    y = (jnp.cos(x) + 1.0) / 2.0 / jnp.pi
    plt.plot(x, y, label="Likelihood")
    plt.ylim(0.0)
    plt.xlim(0.0, 2.0 * jnp.pi)
    plt.legend()
    plt.title("Distribution comparison")

    plt.tight_layout()
    if prior.n_dim == 1:
        plt.savefig("figures/results_uniform.jpg")
    else:
        plt.savefig("figures/results_periodic.jpg")
    plt.close()

# Call at the end if desired
# visualize_results(samples, prior)
src/jimgw/utils.py (1)
107-149: Reduce code duplication in carte_to_spherical_angles.
The function correctly computes spherical coordinates, but duplicates logic already present in `safe_polar_angle` and `safe_arctan2`.
Consider reusing the already implemented functions to reduce code duplication:
 def carte_to_spherical_angles(
     x: Float[Array, " n"],
     y: Float[Array, " n"],
     z: Float[Array, " n"],
     default_value: float = 0.0,
 ) -> Float[Array, " n n"]:
     """
     A numerically stable method to compute the spherical angles upon taking gradient.
     For more details, see
         * `safe_polar_angle` for the polar angle and
         * `safe_arctan2` for the azimuthal angle.

     Parameters
     ==========
     x: array-like
         x-coordinate of the point
     y: array-like
         y-coordinate of the point
     z: array-like
         z-coordinate of the point
     default_value: float
         arctan2 value to return at (0, 0), default is 0.0

     Returns
     =======
     theta: array-like:
         The polar angle, in radians, within [0, π]
     phi: array-like:
         The signed azimuthal angle, in radians, within [-π, π]
     """
-    align_condition = (jnp.absolute(x) < EPS) & (jnp.absolute(y) < EPS)
-    theta = jnp.where(
-        align_condition,
-        jnp.arccos(jnp.sign(z)),
-        jnp.arccos(z / jnp.sqrt(x ** 2 + y ** 2 + z ** 2)),
-    )
-    phi = jnp.where(
-        align_condition,
-        default_value * jnp.ones_like(x),
-        jnp.atan2(y, x),
-    )
+    theta = safe_polar_angle(x, y, z)
+    phi = safe_arctan2(y, x, default_value)
     return theta, phi
src/jimgw/jim.py (2)
33-55: Constructor: Ensure consistent naming and consider user clarity.
Introducing many parameters with default values (e.g., `n_chains`, `n_local_steps`, etc.) is helpful for flexibility. However, be sure that the naming and defaults are sufficiently descriptive so users know which settings to adjust.
77-79: Conditional RNG key check: Confirm user intent.
Using `if rng_key is jax.random.PRNGKey(0):` as a sentinel for a "default key" is workable, yet it could inadvertently skip custom keys with the same initialization seed. Consider clarifying the condition or documenting that this specifically checks for the default key usage.
src/jimgw/prior.py (2)
279-322: New `GaussianPrior` class.
The parameterization with `mu` and `sigma` is straightforward. The usage of `StandardNormalDistribution` plus a scale and offset transform is a robust approach. Verify that edge cases (e.g., `sigma <= 0`) are handled.
+assert sigma > 0, "Standard deviation must be positive in GaussianPrior"
403-430: `RayleighPrior` class for 1D distribution.
Using `UniformPrior(0.0, 1.0, [param_base])` plus `RayleighTransform` is a neat approach. Confirm that `sigma` is positive to avoid invalid transformations.
+assert sigma > 0, "Rayleigh scale parameter must be positive"
src/jimgw/single_event/transforms.py (1)
407-408: Summing the square of the pattern terms for SNR weighting.
`R_dets2 += p_mode_term ** 2 + c_mode_term ** 2` is correct for aggregator logic. Ensure numerical stability if `p_mode_term` or `c_mode_term` approach large values.
test/unit/test_sky_position_transform.ipynb (4)
118-121: Consider using JAX for all random sampling to ensure reproducible results.
Currently, lines 118-121 use NumPy for generating random values, while JAX is used elsewhere. Unifying the random number generation approach can improve consistency and reproducibility.
168-170: Enhance reliability by turning printed diffs into assertions.
Instead of printing delta_x differences, consider adding assertions to ensure that any discrepancies remain below a certain threshold. This will make tests more robust and automated.
207-208: Add a stricter assertion for GMST differences.
You are printing “Mean difference in GMST: ...” without failing the test if the difference grows. Consider implementing an assertion to ensure GMST differences stay below a desired tolerance.
244-246: Incorporate an assertion on RA and DEC differences.
Currently, the code merely prints the mean differences for RA and DEC. Consider enforcing an upper bound using `jnp.allclose` or a similar assertion to detect unexpected discrepancies over the test samples.
test/unit/test_transform.py (4)
147-157: Verify boundary values in SineTransform tests.
You're testing with angle=0.3, which is comfortably within [-π/2, π/2]. Consider adding tests at boundary angles like ±π/2 to ensure correct behavior at the extremes.
176-186: Exercise edge cases in CosineTransform tests.
Currently, the angle used is 1.2. Consider adding boundary angle tests (e.g., 0 and π) to confirm the transform handles extreme inputs properly.
316-318: Rename test_powerlaw_transformn for clarity.
The method name `test_powerlaw_transformn` is not very descriptive. Consider renaming it to something like `test_powerlaw_transform_with_alpha_not_minus_one` for improved readability.
349-367: Extend distance transform tests for more edge cases.
The coverage is good, but consider testing extremely large or small distances (beyond the typical range) to ensure robust numerical handling and reveal potential corner cases.
src/jimgw/transforms.py (2)
325-330: Consider explicit domain checks in SineTransform.
If `x[name_mapping[1][i]]` is outside [-1, 1], `jnp.arcsin` will produce NaNs. Adding pre-checks or error handling can improve debuggability.
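One possible shape for such a pre-check is sketched below. The helper `checked_arcsin` and its `tol` parameter are hypothetical and not part of `jimgw`; the sketch only shows the idea of failing loudly on clearly invalid input while tolerating rounding-level overshoot.

```python
import jax.numpy as jnp


def checked_arcsin(x, tol=1e-6):
    """Hypothetical helper: arcsin with an explicit domain check."""
    x = jnp.asarray(x)
    # Values beyond [-1, 1] by more than `tol` are treated as a caller error.
    if bool(jnp.any(jnp.abs(x) > 1.0 + tol)):
        raise ValueError("arcsin input outside [-1, 1]")
    # Rounding-level overshoot is clipped back onto the boundary.
    return jnp.arcsin(jnp.clip(x, -1.0, 1.0))
```

Because the check uses a Python-level `if` on a concrete array, this form only works outside `jax.jit`. Inside traced code the condition cannot be converted to a Python boolean, so a different mechanism would be needed there.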
503-515: Introduce domain validation for PowerLaw with alpha = -1.0.
If user-supplied values go beyond [0,1], the transform or inverse log calculation might fail. Consider validating input domain or providing an explicit error to avoid silent NaNs.
src/jimgw/single_event/utils.py (3)
42-64: Validate domain or sign constraints for masses.
While this function correctly computes the total mass and mass ratio, consider checking that `m1` and `m2` are both positive to avoid undefined behavior.
550-684: Consider dot-product clamping for iota computation.
In `iota = jnp.arccos(jnp.dot(N, LNh))`, if floating-point rounding pushes the dot product outside [-1,1], `arccos` may return NaNs. Clamping could guard against that.
- iota = jnp.arccos(jnp.dot(N, LNh))
+ dot_val = jnp.clip(jnp.dot(N, LNh), -1.0, 1.0)
+ iota = jnp.arccos(dot_val)
686-807: Broad inverse spin transformation is well-structured.
Similar numeric considerations apply: consider clamping where arcsin/arccos might exceed domain. Otherwise, logic is appropriately mirroring the forward transform.
🛑 Comments failed to post (3)
pyproject.toml (1)
16-16: ⚠️ Potential issue
Updated dependency versions
Updates to flowmc (0.4.2) and jax (≥0.5.0) may introduce compatibility issues. The pipeline failure indicates a missing module `flowMC.resource_strategy_bundle`.
Check that the flowmc version 0.4.2 has the required modules:
#!/bin/bash
# Check for the existence of the resource_strategy_bundle module in flowmc
pip show flowmc
# Try importing the module to see the exact error
python -c "from flowMC.resource_strategy_bundle import RQSpline_MALA_PT_Bundle; print('Module exists')" || echo "Module not found"
Also applies to: 18-18
test/integration/test_periodic_uniform.py (1)
1-8: ⚠️ Potential issue
Clean imports, but fix the pipeline failure.
The pipeline shows a failure with the import of 'flowMC.resource_strategy_bundle' in the Jim class. This is causing the test to fail.
This test depends on a module that's missing or has changed. Please check the Jim class implementation and ensure all dependencies are properly installed or update the import path.
🧰 Tools
🪛 Ruff (0.8.2)
6-6: `jimgw.transforms.BoundToUnbound` imported but unused
Remove unused import: `jimgw.transforms.BoundToUnbound`
(F401)
🪛 GitHub Actions: Run Tests
[error] 1-1: ModuleNotFoundError: No module named 'flowMC.resource_strategy_bundle' imported in src/jimgw/jim.py
src/jimgw/single_event/utils.py (1)
277-278: 🛠️ Refactor suggestion
Ensure real-valued output for eta_to_q.
If `temp**2 < 1`, `(temp**2 - 1)` becomes negative, leading to a potential complex number. Consider enforcing domain constraints or using a safe clamp.
- return temp - (temp**2 - 1)**0.5
+ safe_root = jnp.sqrt(jnp.maximum(temp**2 - 1, 0.0))  # clamp to avoid negative
+ return temp - safe_root
📝 Committable suggestion
safe_root = jnp.sqrt(jnp.maximum(temp**2 - 1, 0.0))  # clamp to avoid negative
return temp - safe_root
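For context, the clamp matters most near equal masses, where the symmetric mass ratio sits at its maximum of 0.25 and rounding can push it slightly above. The sketch below assumes `temp = 1 / (2 * eta) - 1`, which follows from `eta = q / (1 + q)**2` but is not shown in the review excerpt, so treat the definition of `temp` and the name `eta_to_q_sketch` as assumptions.

```python
import jax.numpy as jnp


def eta_to_q_sketch(eta):
    """Hypothetical sketch: mass ratio q <= 1 from symmetric mass ratio eta."""
    eta = jnp.asarray(eta)
    # Assumed definition: temp = (q + 1/q) / 2, from eta = q / (1 + q)**2.
    temp = 1.0 / eta / 2.0 - 1.0
    # Clamp so that rounding near eta = 0.25 cannot make the radicand negative.
    safe_root = jnp.sqrt(jnp.maximum(temp**2 - 1.0, 0.0))
    return temp - safe_root


q_equal_mass = eta_to_q_sketch(0.25)        # exact equal-mass case
q_overshoot = eta_to_q_sketch(0.2500001)    # eta nudged past its maximum
q_half = eta_to_q_sketch(2.0 / 9.0)         # eta for q = 0.5
```

With the clamp, an `eta` marginally above 0.25 maps to a finite value close to 1 instead of an invalid square root. Whether silently accepting such input is desirable, rather than raising, is a design choice left to the maintainers.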
|
@coderabbitai review |
✅ Actions performed
Review triggered.
|
|
@kazewong I have updated the docstring style and fixed the merge conflict, it is ready for review. |
This PR attempts to address part of issue #204, specifically the `nan` values that occur at particular parameter values.
The whole concept is to redefine the `arccos` and `arctan2` functions at the problematic points so that the gradient has a natural extension at those points.
Explicitly, they are:
Summary by CodeRabbit