Avoid NaNs in Spherical and Spins Transforms #205

Merged
kazewong merged 11 commits into jim-dev from fix-spin-transform-nan
May 9, 2025

Conversation

SSL32081 (Collaborator) commented Apr 15, 2025

This PR addresses part of issue #204, specifically the NaNs that occur at particular parameter values.

The idea is to redefine the arccos and arctan2 functions at the problematic points so that the gradient has a natural extension there.

Explicitly:

theta = jnp.where(
    (jnp.abs(x) < EPS) & (jnp.abs(y) < EPS),
    jnp.arccos(jnp.sign(z)),
    jnp.arccos(z / jnp.sqrt(x ** 2 + y ** 2 + z ** 2)),
)

phi = jnp.where(
    (jnp.abs(x) < EPS) & (jnp.abs(y) < EPS),
    default_value * jnp.ones_like(x),
    jnp.atan2(y, x),
)
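
One caveat worth checking when using this pattern in JAX: `jnp.where` differentiates both branches, so an infinite or NaN derivative in the unselected branch can still reach the final gradient (0 times inf). A common workaround is the "double where" trick, which also replaces the inputs fed to the unsafe branch. The following is a minimal sketch of that idea, not the code merged in this PR; the function name `spherical_angles_sketch`, the value of `EPS`, and the placeholder inputs are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

EPS = 1e-15  # assumed threshold; the value used in the PR is not shown in this thread


def spherical_angles_sketch(x, y, z, default_value=0.0):
    """Illustrative only: spherical angles whose gradients stay finite on the z-axis."""
    on_axis = (jnp.abs(x) < EPS) & (jnp.abs(y) < EPS)

    # Inner where: feed harmless placeholder inputs to the general formulas,
    # so the branch that is later discarded never sees a singular point.
    x_safe = jnp.where(on_axis, 1.0, x)
    y_safe = jnp.where(on_axis, 0.0, y)
    r = jnp.sqrt(x_safe**2 + y_safe**2 + z**2)

    # Outer where: pick the redefined value on the axis, the usual one elsewhere.
    theta = jnp.where(on_axis, jnp.arccos(jnp.sign(z)), jnp.arccos(z / r))
    phi = jnp.where(on_axis, default_value, jnp.arctan2(y_safe, x_safe))
    return theta, phi


# Gradient of (theta + phi) at a point on the z-axis.
point = jnp.array([0.0, 0.0, 1.0])
grad_on_axis = jax.grad(lambda v: sum(spherical_angles_sketch(v[0], v[1], v[2])))(point)
```

With this arrangement, `grad_on_axis` should contain only finite entries, while points away from the axis go through the ordinary `arccos` and `arctan2` expressions.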

Summary by CodeRabbit

  • New Features
    • Improved numerical stability for angle calculations, reducing the risk of undefined values in edge cases.
  • Refactor
    • Unified and streamlined angle computations across the application for consistency.
  • Style
    • Enhanced code readability with minor formatting and spacing adjustments.

coderabbitai bot commented Apr 15, 2025

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

The changes introduce new numerically stable utility functions for converting Cartesian coordinates to spherical angles, addressing edge cases where gradients could become undefined. These functions are added to a utility module and then integrated into existing transformation and conversion routines, replacing direct calls to standard trigonometric functions. The updates also include minor formatting improvements for code clarity and consistency, particularly around exponentiation operators and angle normalization. No public API signatures are changed, and the core logic remains intact.

Changes

| File(s) | Change Summary |
| --- | --- |
| src/jimgw/utils.py | Added constant EPS and three new functions: safe_arctan2, safe_polar_angle, and carte_to_spherical_angles for numerically stable spherical angle computations. |
| src/jimgw/single_event/utils.py | Replaced direct trigonometric computations with the new utility functions for angle calculations; applied consistent formatting and spacing in mathematical expressions. |
| src/jimgw/single_event/transforms.py | Updated imports to include carte_to_spherical_angles; refactored angle computations in inverse transforms to use new utility; improved code formatting and angle normalization. |

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Transform
    participant Utils

    User->>Transform: Request inverse transform (e.g., SphereSpinToCartesianSpinTransform)
    Transform->>Utils: Call carte_to_spherical_angles(x, y, z)
    Utils-->>Transform: Return (theta, phi) with stable gradients
    Transform-->>User: Return spherical angles (theta, phi)

Poem

In the warren where numbers spin and twirl,
We bunnies found gradients in a swirl.
With stable angles now, no NaNs in sight—
Our math hops gently, calculations light!
From Cartesian to spheres, with trigonometric cheer,
These fluffy improvements make the code more clear.
🥕✨


@SSL32081 SSL32081 changed the base branch from main to jim-dev April 15, 2025 18:40

@coderabbitai coderabbitai bot left a comment

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 3

🔭 Outside diff range comments (1)
src/jimgw/single_event/utils.py (1)

280-320: 🛠️ Refactor suggestion

Clamping cos_beta to avoid NaNs.
sin_beta = jnp.sqrt(1 - cos_beta**2) can produce NaNs if cos_beta**2 slightly exceeds 1 due to floating-point rounding. Consider clamping cos_beta into [-1, 1] right before.

- sin_beta = jnp.sqrt(1 - cos_beta ** 2)
+ cos_beta_clamped = jnp.clip(cos_beta, -1.0, 1.0)
+ sin_beta = jnp.sqrt(1.0 - cos_beta_clamped**2)
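
To illustrate the failure mode described above, here is a small self-contained sketch. The helper name `sin_from_cos` is hypothetical, and float32 (JAX's default precision) is assumed. It builds a cosine value that overshoots 1 by a single rounding step and compares the unclamped and clamped results.

```python
import jax.numpy as jnp


def sin_from_cos(cos_beta):
    """Hypothetical helper: non-negative sin(beta) from cos(beta), clamped to [-1, 1]."""
    cos_beta_clamped = jnp.clip(cos_beta, -1.0, 1.0)
    return jnp.sqrt(1.0 - cos_beta_clamped**2)


# A value one float32 rounding step above 1, as accumulated error might produce.
overshoot = jnp.nextafter(jnp.float32(1.0), jnp.float32(2.0))

unclamped = jnp.sqrt(1.0 - overshoot**2)  # square root of a small negative number
clamped = sin_from_cos(overshoot)         # radicand is exactly zero after clipping
```

Clipping keeps the value real, but the derivative of the square root is still unbounded where the radicand is zero, so gradients taken exactly at the boundary may need separate handling.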
🧹 Nitpick comments (25)
test/integration/test_GW150914_D.py (1)

114-115: Commented out results processing

The lines for retrieving and printing samples are commented out, making this more of a smoke test for initialization rather than validating results. Consider adding a basic assertion to ensure the test is meaningful.

Add a basic assertion to check that sampling completed successfully rather than just commenting out the result processing:

-# jim.get_samples()
-# jim.print_summary()
+# Minimal validation that sampling worked
+samples = jim.get_samples()
+assert samples is not None, "Sampling did not produce any results"
test/unit/source_files/sky_locations/bilby_sky_locations.py (1)

36-39: Verify output directory exists

The script assumes the output directory exists, but doesn't check or create it if missing.

Consider adding a directory check/creation before saving:

+import os
+os.makedirs(f"{outdir}/sky_locations", exist_ok=True)
+
jnp.savez(
    f"{outdir}/sky_locations/test_{ifo_pair[0]}_{ifo_pair[1]}_{time}.npz",
    **input_dict,
)
example/GW150914_IMRPhenomD.py (1)

25-25: Remove unused import.

optax is imported but not used within this file.

- import optax
🧰 Tools
🪛 Ruff (0.8.2)

25-25: optax imported but unused

Remove unused import: optax

(F401)

src/jimgw/constants.py (1)

1-1: Acknowledge TODO comment on type stubs.

The comment indicates missing type stubs for astropy. Consider addressing them in a future improvement to enhance type checking.

test/integration/test_periodic_uniform.py (2)

6-6: Remove unused import.

The BoundToUnbound import is only used in commented-out code.

-from jimgw.transforms import PeriodicTransform, BoundToUnbound
+from jimgw.transforms import PeriodicTransform
🧰 Tools
🪛 Ruff (0.8.2)

6-6: jimgw.transforms.BoundToUnbound imported but unused

Remove unused import: jimgw.transforms.BoundToUnbound

(F401)


76-99: Useful visualization code.

The commented visualization code provides a good way to analyze the sampling results. Consider uncommenting it or moving it to a separate function that can be optionally called.

Consider moving the visualization code to a separate function that can be called optionally:

def visualize_results(samples, prior):
    import matplotlib.pyplot as plt
    
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(samples["test"])
    plt.ylim(0.0, 2.0 * jnp.pi)
    plt.xlim(0, 300)
    plt.title("Walker history")
    
    plt.subplot(1, 2, 2)
    plt.hist(samples["test"], label="Samples", density=True, bins=50)
    x = jnp.linspace(0.0, 2.0 * jnp.pi, 1000)
    y = (jnp.cos(x) + 1.0) / 2.0 / jnp.pi
    plt.plot(x, y, label="Likelihood")
    plt.ylim(0.0)
    plt.xlim(0.0, 2.0 * jnp.pi)
    plt.legend()
    plt.title("Distribution comparison")
    
    plt.tight_layout()
    if prior.n_dim == 1:
        plt.savefig("figures/results_uniform.jpg")
    else:
        plt.savefig("figures/results_periodic.jpg")
    plt.close()

# Call at the end if desired
# visualize_results(samples, prior)
src/jimgw/utils.py (1)

107-149: Reduce code duplication in carte_to_spherical_angles.

The function correctly computes spherical coordinates, but duplicates logic already present in safe_polar_angle and safe_arctan2.

Consider reusing the already implemented functions to reduce code duplication:

def carte_to_spherical_angles(
    x: Float[Array, " n"],
    y: Float[Array, " n"],
    z: Float[Array, " n"],
    default_value: float = 0.0,
) -> Float[Array, " n n"]:
    """
    A numerically stable method to compute the spherical angles upon taking gradient.

    For more details, see
    * `safe_polar_angle` for the polar angle and
    * `safe_arctan2` for the azimuthal angle.

    Parameters
    ==========
    x: array-like
        x-coordinate of the point
    y: array-like
        y-coordinate of the point
    z: array-like
        z-coordinate of the point
    default_value: float
        arctan2 value to return at (0, 0), default is 0.0

    Returns
    =======
    theta: array-like:
        The polar angle, in radians, within [0, π]
    phi: array-like:
        The signed azimuthal angle, in radians, within [-π, π]
    """
-    align_condition = (jnp.absolute(x) < EPS) & (jnp.absolute(y) < EPS)
-    theta = jnp.where(
-        align_condition,
-        jnp.arccos(jnp.sign(z)),
-        jnp.arccos(z / jnp.sqrt(x ** 2 + y ** 2 + z ** 2)),
-    )
-    phi = jnp.where(
-        align_condition,
-        default_value * jnp.ones_like(x),
-        jnp.atan2(y, x),
-    )
+    theta = safe_polar_angle(x, y, z)
+    phi = safe_arctan2(y, x, default_value)
    return theta, phi
src/jimgw/jim.py (2)

33-55: Constructor: Ensure consistent naming and consider user clarity.
Introducing many parameters with default values (e.g., n_chains, n_local_steps, etc.) is helpful for flexibility. However, be sure that the naming and defaults are sufficiently descriptive so users know which settings to adjust.


77-79: Conditional RNG key check: Confirm user intent.
Using if rng_key is jax.random.PRNGKey(0): as a sentinel for a "default key" is workable, yet it could inadvertently skip custom keys with the same initialization seed. Consider clarifying the condition or documenting that this specifically checks for the default key usage.
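
For context on the check above: Python's `is` compares object identity, and each call to `jax.random.PRNGKey(0)` builds a new array object, so an identity test against a freshly created key is not expected to match even for the same seed. A `None` sentinel is a common alternative. The sketch below is illustrative only; `resolve_rng_key` is a hypothetical helper and not part of jimgw.

```python
import jax
import jax.numpy as jnp


def resolve_rng_key(rng_key=None):
    """Hypothetical helper: fall back to a default key only when none is given."""
    if rng_key is None:
        rng_key = jax.random.PRNGKey(0)
    return rng_key


default_key = resolve_rng_key()
custom_key = resolve_rng_key(jax.random.PRNGKey(0))  # same seed, still honored

# Two keys built from the same seed hold equal data but are distinct objects.
same_object = jax.random.PRNGKey(0) is jax.random.PRNGKey(0)
same_values = bool(jnp.array_equal(default_key, custom_key))
```

With a `None` default, a caller who deliberately passes a key seeded with 0 is treated the same as any other caller who supplies a key.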

src/jimgw/prior.py (2)

279-322: New GaussianPrior class.
The parameterization with mu and sigma is straightforward. The usage of StandardNormalDistribution plus a scale and offset transform is a robust approach. Verify that edge cases (e.g., sigma <= 0) are handled.

+assert sigma > 0, "Standard deviation must be positive in GaussianPrior"

403-430: RayleighPrior class for 1D distribution.
Using UniformPrior(0.0, 1.0, [param_base]) plus RayleighTransform is a neat approach. Confirm that sigma is positive to avoid invalid transformations.

+assert sigma > 0, "Rayleigh scale parameter must be positive"
src/jimgw/single_event/transforms.py (1)

407-408: Summing the square of the pattern terms for SNR weighting.
R_dets2 += p_mode_term ** 2 + c_mode_term ** 2 is correct for aggregator logic. Ensure numerical stability if p_mode_term or c_mode_term approach large values.

test/unit/test_sky_position_transform.ipynb (4)

118-121: Consider using JAX for all random sampling to ensure reproducible results.
Currently, lines 118-121 use NumPy for generating random values, while JAX is used elsewhere. Unifying the random number generation approach can improve consistency and reproducibility.


168-170: Enhance reliability by turning printed diffs into assertions.
Instead of printing delta_x differences, consider adding assertions to ensure that any discrepancies remain below a certain threshold. This will make tests more robust and automated.


207-208: Add a stricter assertion for GMST differences.
You are printing “Mean difference in GMST: ...” without failing the test if the difference grows. Consider implementing an assertion to ensure GMST differences stay below a desired tolerance.


244-246: Incorporate an assertion on RA and DEC differences.
Currently, the code merely prints the mean differences for RA and DEC. Consider enforcing an upper bound using jnp.allclose or a similar assertion to detect unexpected discrepancies over the test samples.

test/unit/test_transform.py (4)

147-157: Verify boundary values in SineTransform tests.
You're testing with angle=0.3, which is comfortably within [-π/2, π/2]. Consider adding tests at boundary angles like ±π/2 to ensure correct behavior at the extremes.


176-186: Exercise edge cases in CosineTransform tests.
Currently, the angle used is 1.2. Consider adding boundary angle tests (e.g., 0 and π) to confirm the transform handles extreme inputs properly.


316-318: Rename test_powerlaw_transformn for clarity.
The method name test_powerlaw_transformn is not very descriptive. Consider renaming it to something like test_powerlaw_transform_with_alpha_not_minus_one for improved readability.


349-367: Extend distance transform tests for more edge cases.
The coverage is good, but consider testing extremely large or small distances (beyond the typical range) to ensure robust numerical handling and reveal potential corner cases.

src/jimgw/transforms.py (2)

325-330: Consider explicit domain checks in SineTransform.
If x[name_mapping[1][i]] is outside [-1, 1], jnp.arcsin will produce NaNs. Adding pre-checks or error handling can improve debuggability.


503-515: Introduce domain validation for PowerLaw with alpha = -1.0.
If user-supplied values go beyond [0,1], the transform or inverse log calculation might fail. Consider validating input domain or providing an explicit error to avoid silent NaNs.

src/jimgw/single_event/utils.py (3)

42-64: Validate domain or sign constraints for masses.
While this function correctly computes the total mass and mass ratio, consider checking that m1 and m2 are both positive to avoid undefined behavior.


550-684: Consider dot-product clamping for iota computation.
In iota = jnp.arccos(jnp.dot(N, LNh)), if floating-point rounding pushes the dot product outside [-1,1], arccos may return NaNs. Clamping could guard against that.

- iota = jnp.arccos(jnp.dot(N, LNh))
+ dot_val = jnp.clip(jnp.dot(N, LNh), -1.0, 1.0)
+ iota = jnp.arccos(dot_val)

686-807: Broad inverse spin transformation is well-structured.
Similar numeric considerations apply: consider clamping where arcsin/arccos might exceed domain. Otherwise, logic is appropriately mirroring the forward transform.

🛑 Comments failed to post (3)
pyproject.toml (1)

16-16: ⚠️ Potential issue

Updated dependency versions

Updates to flowmc (0.4.2) and jax (≥0.5.0) may introduce compatibility issues. The pipeline failure indicates a missing module flowMC.resource_strategy_bundle.

Check that the flowmc version 0.4.2 has the required modules:

#!/bin/bash
# Check for the existence of the resource_strategy_bundle module in flowmc
pip show flowmc
# Try importing the module to see the exact error
python -c "from flowMC.resource_strategy_bundle import RQSpline_MALA_PT_Bundle; print('Module exists')" || echo "Module not found"

Also applies to: 18-18

test/integration/test_periodic_uniform.py (1)

1-8: ⚠️ Potential issue

Clean imports, but fix the pipeline failure.

The pipeline shows a failure with the import of 'flowMC.resource_strategy_bundle' in the Jim class. This is causing the test to fail.

This test depends on a module that's missing or has changed. Please check the Jim class implementation and ensure all dependencies are properly installed or update the import path.

🧰 Tools
🪛 Ruff (0.8.2)

6-6: jimgw.transforms.BoundToUnbound imported but unused

Remove unused import: jimgw.transforms.BoundToUnbound

(F401)

🪛 GitHub Actions: Run Tests

[error] 1-1: ModuleNotFoundError: No module named 'flowMC.resource_strategy_bundle' imported in src/jimgw/jim.py

src/jimgw/single_event/utils.py (1)

277-278: 🛠️ Refactor suggestion

Ensure real-valued output for eta_to_q.
If temp**2 < 1, (temp**2 - 1) becomes negative, leading to a potential complex number. Consider enforcing domain constraints or using a safe clamp.

- return temp - (temp**2 - 1)**0.5
+ safe_root = jnp.sqrt(jnp.maximum(temp**2 - 1, 0.0))  # clamp to avoid negative
+ return temp - safe_root
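
As a sketch of the clamped version in context: the snippet below assumes `temp = 1 / (2 * eta) - 1`, which is not shown in this thread but follows from eta = q / (1 + q)**2, and it uses the hypothetical name `eta_to_q_sketch`. It is meant to show the effect of the clamp near eta = 0.25 (equal masses), where rounding can push `temp**2 - 1` slightly below zero.

```python
import jax.numpy as jnp


def eta_to_q_sketch(eta):
    """Hypothetical variant: mass ratio q <= 1 from the symmetric mass ratio eta."""
    # Assumed definition of temp, derived from eta = q / (1 + q)**2.
    temp = 1.0 / (2.0 * eta) - 1.0
    # Clamp the radicand at zero so rounding errors cannot produce NaN.
    safe_root = jnp.sqrt(jnp.maximum(temp**2 - 1.0, 0.0))
    return temp - safe_root


q_equal_mass = eta_to_q_sketch(jnp.float32(0.25))       # boundary case
q_overshoot = eta_to_q_sketch(jnp.float32(0.2500001))   # eta just above 0.25
q_two_to_one = eta_to_q_sketch(jnp.float32(2.0 / 9.0))  # eta corresponding to q = 0.5
```

Note that with floating-point JAX arrays a negative radicand yields NaN rather than a complex number, so the clamp mainly guards against NaNs at the equal-mass boundary.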

@SSL32081 (Collaborator, Author)

@coderabbitai review

coderabbitai bot commented Apr 15, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@SSL32081 SSL32081 self-assigned this Apr 15, 2025
SSL32081 (Collaborator, Author) commented May 2, 2025

@kazewong I have updated the docstring style and fixed the merge conflict. It is ready for review.

@SSL32081 SSL32081 requested a review from kazewong May 2, 2025 11:34
kazewong (Owner) left a comment

Approved

@kazewong kazewong merged commit 69ae276 into jim-dev May 9, 2025
1 of 7 checks passed
@kazewong kazewong deleted the fix-spin-transform-nan branch May 9, 2025 04:36

2 participants