Conversation
…rovide default behavior
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the WalkthroughThe changes update the random number generation within the Changes
Sequence Diagram(s)sequenceDiagram
participant U as User
participant J as Jim Class
participant R as jax.random
U->>J: Initialize Jim (rng_key: None or provided)
alt rng_key not provided
J->>R: Generate default key using PRNGKey(0)
J->>J: Set internal rng_key with default
else rng_key provided
J->>J: Set internal rng_key with provided value
end
U->>J: Call sample(key: optional)
alt key provided
J->>J: Print "rng_key is being overwritten"
J->>R: Split key and update sampler's rng_key
else key not provided
J->>J: Use current rng_key during sampling
end
Poem
🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Actionable comments posted: 0
🔭 Outside diff range comments (1)
src/jimgw/jim.py (1)
112-144:⚠️ Potential issueBug: key used without checking if it's None
When
keyis None (the default), the code will fail at line 130 where it attempts to split the key without checking if it's None first.if key is not None: print("Overwriting rng_key") key, self.sampler.rng_key = jax.random.split(key) +else: + key = self.sampler.rng_key # Use the sampler's key if none provided while not jax.tree.reduce( jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position), ).all(): # ...rest of the codeThis ensures that
keyis never None when used in the following code.
🧹 Nitpick comments (4)
src/jimgw/jim.py (4)
1-7: Missing import for Optional typeThe code now uses an optional
rng_keyparameter, but theOptionaltype hint from thetypingmodule is missing.import jax import jax.numpy as jnp from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.proposal.MALA import MALA from flowMC.Sampler import Sampler from jaxtyping import Array, Float, PRNGKeyArray +from typing import Optional
56-59: Good refactor to key-based RNG, but minor improvements neededThe transition from seed-based to key-based RNG is a good improvement. However, there are a few minor issues:
- The message says "default seed" but you're actually using a default key
- Users can't configure the default seed value (0)
rng_key = kwargs.get("rng_key", None) if rng_key is None: - print("No rng_key provided. Using default seed.") + print("No rng_key provided. Using default key with seed=0.") rng_key = jax.random.PRNGKey(0)Consider also allowing users to configure the default seed:
rng_key = kwargs.get("rng_key", None) +default_seed = kwargs.get("default_seed", 0) if rng_key is None: - print("No rng_key provided. Using default seed.") + print(f"No rng_key provided. Using default key with seed={default_seed}.") - rng_key = jax.random.PRNGKey(0) + rng_key = jax.random.PRNGKey(default_seed)
106-106: Add Optional type hint to key parameterSince the
keyparameter can now beNone, add the proper type annotation.-def sample(self, key: PRNGKeyArray = None, initial_position: Array = jnp.array([])): +def sample(self, key: Optional[PRNGKeyArray] = None, initial_position: Array = jnp.array([])):
112-115: Improved message for when RNG key is overwrittenThe message "Overwriting rng_key" doesn't provide enough context about what's happening.
if key is not None: - print("Overwriting rng_key") + print("Provided key will override the existing sampler RNG key") key, self.sampler.rng_key = jax.random.split(key)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/jimgw/jim.py(2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (4)
- GitHub Check: pre-commit (3.11)
- GitHub Check: pre-commit (3.10)
- GitHub Check: build (3.11)
- GitHub Check: build (3.10)
…d allow optional key parameter in sample method
|
@kazewong The keys in See if you think the changes I made are fine. Please tell me if you have better ideas. |
kazewong
left a comment
There was a problem hiding this comment.
Merging this for now. In the next revision, we want to move to more structure way to define jim, so keyword arguments and states should only be declared once
Summary by CodeRabbit