Skip to content

Refactor RNG key handling#199

Merged
kazewong merged 2 commits intojim-devfrom
fix-seed-issue
Apr 3, 2025
Merged

Refactor RNG key handling#199
kazewong merged 2 commits intojim-devfrom
fix-seed-issue

Conversation

@thomasckng
Copy link
Collaborator

@thomasckng thomasckng commented Apr 2, 2025

Summary by CodeRabbit

  • Refactor
    • Enhanced randomness management for initialization and sampling.
    • Now, a custom random configuration can be provided; otherwise, a default is automatically generated.
    • Clear notifications are provided when a custom configuration is applied.

@coderabbitai
Copy link

coderabbitai bot commented Apr 2, 2025

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

The changes update the random number generation within the Jim class. In the __init__ method, the seed parameter is replaced by rng_key, which defaults to None. If no key is provided, a default key is generated using jax.random.PRNGKey(0). In the sample method, the key parameter now defaults to None. When a key is provided during sampling, it prints a message and splits the key to update the internal state. These modifications transition the randomness management from a seed-based approach to a key-based approach.

Changes

File Summary of Modifications
src/jimgw/jim.py - Updated __init__ method: replaced seed with rng_key defaulting to None; generates a default key via jax.random.PRNGKey(0) when not provided.
- Updated sample method: key now defaults to None, and if provided, prints a message and splits the key to update the sampler's rng_key.

Sequence Diagram(s)

sequenceDiagram
    participant U as User
    participant J as Jim Class
    participant R as jax.random

    U->>J: Initialize Jim (rng_key: None or provided)
    alt rng_key not provided
        J->>R: Generate default key using PRNGKey(0)
        J->>J: Set internal rng_key with default
    else rng_key provided
        J->>J: Set internal rng_key with provided value
    end

    U->>J: Call sample(key: optional)
    alt key provided
        J->>J: Print "rng_key is being overwritten"
        J->>R: Split key and update sampler's rng_key
    else key not provided
        J->>J: Use current rng_key during sampling
    end
Loading

Poem

I'm a little rabbit, hopping with cheer,
Seeing code transform, oh so clear!
Seeds are out, new keys in play,
Splitting and managing randomness the new way.
With a joyful hop, I celebrate each change today!
🐰💕


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai plan to trigger planning for file edits and PR creation.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@thomasckng thomasckng requested a review from kazewong April 2, 2025 08:36
@thomasckng
Copy link
Collaborator Author

@coderabbitai review

@coderabbitai
Copy link

coderabbitai bot commented Apr 2, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
src/jimgw/jim.py (1)

112-144: ⚠️ Potential issue

Bug: key used without checking if it's None

When key is None (the default), the code will fail at line 130 where it attempts to split the key without checking if it's None first.

if key is not None:
    print("Overwriting rng_key")
    key, self.sampler.rng_key = jax.random.split(key)
+else:
+    key = self.sampler.rng_key  # Use the sampler's key if none provided

while not jax.tree.reduce(
    jnp.logical_and,
    jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
).all():
    # ...rest of the code

This ensures that key is never None when used in the following code.

🧹 Nitpick comments (4)
src/jimgw/jim.py (4)

1-7: Missing import for Optional type

The code now uses an optional rng_key parameter, but the Optional type hint from the typing module is missing.

import jax
import jax.numpy as jnp
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.proposal.MALA import MALA
from flowMC.Sampler import Sampler
from jaxtyping import Array, Float, PRNGKeyArray
+from typing import Optional

56-59: Good refactor to key-based RNG, but minor improvements needed

The transition from seed-based to key-based RNG is a good improvement. However, there are a few minor issues:

  1. The message says "default seed" but you're actually using a default key
  2. Users can't configure the default seed value (0)
rng_key = kwargs.get("rng_key", None)
if rng_key is None:
-    print("No rng_key provided. Using default seed.")
+    print("No rng_key provided. Using default key with seed=0.")
    rng_key = jax.random.PRNGKey(0)

Consider also allowing users to configure the default seed:

rng_key = kwargs.get("rng_key", None)
+default_seed = kwargs.get("default_seed", 0)
if rng_key is None:
-    print("No rng_key provided. Using default seed.")
+    print(f"No rng_key provided. Using default key with seed={default_seed}.")
-    rng_key = jax.random.PRNGKey(0)
+    rng_key = jax.random.PRNGKey(default_seed)

106-106: Add Optional type hint to key parameter

Since the key parameter can now be None, add the proper type annotation.

-def sample(self, key: PRNGKeyArray = None, initial_position: Array = jnp.array([])):
+def sample(self, key: Optional[PRNGKeyArray] = None, initial_position: Array = jnp.array([])):

112-115: Improved message for when RNG key is overwritten

The message "Overwriting rng_key" doesn't provide enough context about what's happening.

if key is not None:
-    print("Overwriting rng_key")
+    print("Provided key will override the existing sampler RNG key")
     key, self.sampler.rng_key = jax.random.split(key)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e25ac98 and dccd431.

📒 Files selected for processing (1)
  • src/jimgw/jim.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
  • GitHub Check: pre-commit (3.11)
  • GitHub Check: pre-commit (3.10)
  • GitHub Check: build (3.11)
  • GitHub Check: build (3.10)

…d allow optional key parameter in sample method
@thomasckng
Copy link
Collaborator Author

@kazewong The keys in jim.sample() and jim.__init__() are a bit weirdly used now. The key in sample() is just for initial point generation, and the key in __init__() is for sampler sampling and MaskedCouplingRQSpline initialization.

See if you think the changes I made are fine. Please tell me if you have better ideas.

Copy link
Owner

@kazewong kazewong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging this for now. In the next revision, we want to move to more structure way to define jim, so keyword arguments and states should only be declared once

@kazewong kazewong merged commit 5522775 into jim-dev Apr 3, 2025
1 of 5 checks passed
@kazewong kazewong deleted the fix-seed-issue branch April 3, 2025 13:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants