Conversation
- Updated the usage of `PRNGKeyArray` to `Key` in various modules to align with the latest jaxtyping standards.
- Modified the initialization of random keys in tests to use the new key generation method.
- Ensured consistency in random key handling across resource strategies, strategies, and tests.
feat: implement early stopping strategy
Warning: Rate limit exceeded

⌛ How to resolve this issue? After the wait time has elapsed, a review can be triggered. We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work? CodeRabbit enforces hourly rate limits for each developer per organization. Our paid plans have higher rate limits than the trial, open-source, and free plans. In all cases, we re-allow further reviews after a brief timeout. Please see our FAQ for further information.

⛔ Files ignored due to path filters (1)
📒 Files selected for processing (1)
📝 Walkthrough

Migrates JAX RNG typing from `PRNGKeyArray` to `Key`.

Changes
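For context, the migration described in the walkthrough looks roughly like this at the call sites. This is a sketch using only `jax.random` APIs, not flowMC's actual code:

```python
import jax

# Legacy style: jax.random.PRNGKey returns a raw uint32[2] array, typed
# as PRNGKeyArray in older jaxtyping annotations.
old_key = jax.random.PRNGKey(0)

# New style: jax.random.key returns a scalar with an opaque key dtype,
# matching the `Key` annotation this PR migrates to.
new_key = jax.random.key(0)

# Both drive the same samplers; only the dtype and annotations change.
sample = jax.random.normal(new_key, (3,))
print(old_key.shape, new_key.shape, sample.shape)  # (2,) () (3,)
```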
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Sampler
    participant Strategy as Training\nStrategy
    participant State
    participant Resources

    Sampler->>Strategy: Call strategy(rng_key, resources, ...)
    Strategy-->>Sampler: (rng_key, resources, position)
    Sampler->>State: Read state.data["early_stopped"]
    alt early_stopped == True
        Sampler->>Sampler: Set skip_to_production = True
        Sampler->>Strategy: Skip remaining training strategies until reset_steppers
    else
        Sampler->>Strategy: Continue calling next training strategy
    end
    alt on reset_steppers
        Sampler->>State: Clear state.data["early_stopped"]
        Sampler->>Sampler: Set skip_to_production = False (enter production)
    end
```
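In plain Python, the skip-to-production control flow in the diagram above can be sketched as follows. All names are hypothetical; flowMC's Sampler differs in detail:

```python
def run_training_loop(strategies, state):
    """Hedged sketch of the early-stop skip logic shown in the diagram."""
    skip_to_production = False
    for strategy in strategies:
        # Once early stopping fires, skip everything until reset_steppers.
        if skip_to_production and strategy.__name__ != "reset_steppers":
            continue
        strategy(state)
        if state.get("early_stopped"):
            skip_to_production = True
        if strategy.__name__ == "reset_steppers":
            state["early_stopped"] = False  # clear the flag
            skip_to_production = False      # enter production
    return state


def warmup(state): state["calls"].append("warmup")
def check_early_stop(state): state["calls"].append("check"); state["early_stopped"] = True
def more_training(state): state["calls"].append("train")
def reset_steppers(state): state["calls"].append("reset")

state = run_training_loop(
    [warmup, check_early_stop, more_training, reset_steppers],
    {"calls": [], "early_stopped": False},
)
print(state["calls"])  # ['warmup', 'check', 'reset'] — 'train' was skipped
```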
```mermaid
sequenceDiagram
    participant Bundle
    participant Check as CheckEarlyStop\nStrategy
    participant State
    participant Buffer as Acceptance\nBuffer

    Bundle->>Check: Instantiate(params: tolerance, patience, ...)
    loop each training loop
        Bundle->>Check: __call__(rng_key, resources, initial_position, data)
        Check->>State: Read state resource (state_name)
        State->>Buffer: Provide acceptance buffer data
        Check->>Check: Compute mean & CoV, compare to previous
        alt Stability & patience reached OR min_acceptance reached
            Check->>State: state.data["early_stopped"] = True
            Check-->>Bundle: Return updated resources with early_stopped
        else
            Check-->>Bundle: Return without change
        end
    end
```
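The stability test in this diagram — mean and coefficient of variation (CoV) of the acceptance buffer, gated by tolerance and patience — might look roughly like the sketch below. The names and the exact stopping rule are assumptions, not flowMC's actual implementation:

```python
import jax.numpy as jnp

def check_early_stop(acceptances, prev_mean, patience_counter,
                     tolerance=1e-2, patience=3):
    """Hedged sketch: stop once the acceptance mean has been stable
    for `patience` consecutive checks."""
    mean = float(jnp.mean(acceptances))
    cov = float(jnp.std(acceptances) / (mean + 1e-12))  # coefficient of variation
    stable = prev_mean is not None and abs(mean - prev_mean) < tolerance
    patience_counter = patience_counter + 1 if stable else 0
    early_stopped = patience_counter >= patience
    return mean, cov, patience_counter, early_stopped

buf = jnp.full((10,), 0.5)
mean, cov, counter, stop = check_early_stop(buf, prev_mean=None, patience_counter=0)
assert not stop  # first check: nothing to compare against yet
mean, cov, counter, stop = check_early_stop(buf, prev_mean=mean, patience_counter=2)
print(stop)  # True — stable mean and patience reached
```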
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels
🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 5
🧹 Nitpick comments (5)
src/flowMC/resource/model/common.py (1)
54-55: Unused `key` parameter in `Distribution.__call__`.

The `key` parameter is declared but unused — the method simply delegates to `self.log_prob(x)`. This appears intentional for API consistency, allowing subclasses to optionally use stochastic evaluation. Consider adding a brief comment or using `del key` to explicitly acknowledge this is intentional, similar to the pattern suggested for `rqSpline.py`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/common.py` around lines 54 - 55: the unused `key` parameter in `Distribution.__call__` should be explicitly acknowledged. Inside the `__call__` method, either add a short comment noting that `key` is kept for API compatibility or add a no-op statement like `del key` before returning `self.log_prob(x)`, so linters and readers know the unused parameter is intentional; update only the `__call__` method to avoid changing behavior.

src/flowMC/resource/model/nf_model/rqSpline.py (1)
453-458: Unused `key` parameter in `forward` method.

The static analysis correctly identifies that the `key` parameter is declared but never used within the `forward` method. This parameter appears to be included for API consistency with other methods or potential future use with stochastic forward passes. If the parameter is intentional for interface consistency, consider adding a brief comment or using `_ = key` to explicitly acknowledge it's unused. Otherwise, if it's not needed, consider removing it.

🔧 Option to explicitly acknowledge unused parameter

```diff
 def forward(
     self,
     x: Float[Array, " n_dim"],
     key: Optional[Key] = None,
     condition: Optional[Float[Array, " n_condition"]] = None,
 ) -> tuple[Float[Array, " n_dim"], Float]:
+    del key  # Unused, included for API consistency
     log_det = 0.0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/nf_model/rqSpline.py` around lines 453 - 458: the `forward` method declares an unused parameter `key`; either remove `key` from the `forward` signature if it's unnecessary, or preserve API consistency by explicitly acknowledging it inside the method (e.g., add a line like `_ = key` or a short comment `# key unused (kept for API compatibility)` near the top of `rqSpline.forward`) so static analysis no longer flags it; reference the `forward` method and the `key` parameter when making the change.

src/flowMC/strategy/optimization.py (1)
95-122: Misplaced docstring will not be recognised by Python.

The docstring for the `optimize` method (lines 103-122) is placed after the bounds validation code (lines 95-101). In Python, docstrings must immediately follow the function signature to be accessible via `help()` or `__doc__`. This appears to be a pre-existing issue but worth addressing.

♻️ Proposed fix to relocate the docstring

```diff
 def optimize(
     self,
     rng_key: Key,
     objective: Callable,
     initial_position: Float[Array, " n_chain n_dim"],
     data: dict,
 ):
+    """Optimization kernel. This can be used independently of the __call__ method.
+
+    Args:
+        rng_key: Key
+            Random key for the optimization.
+        objective: Callable
+            Objective function to optimize.
+        initial_position: Float[Array, " n_chain n_dim"]
+            Initial positions for the optimization.
+        data: dict
+            Data to pass to the objective function.
+
+    Returns:
+        rng_key: Key
+            Updated random key.
+        optimized_positions: Float[Array, " n_chain n_dim"]
+            Optimized positions.
+        final_log_prob: Float[Array, " n_chain"]
+            Final log-probabilities of the optimized positions.
+    """
     # Validate bounds shape against n_dim
     n_dim = initial_position.shape[-1]
     if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
         raise ValueError(
             f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
             "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
         )
-    """Optimization kernel. This can be used independently of the __call__ method.
-
-    Args:
-        rng_key: Key
-            Random key for the optimization.
-        objective: Callable
-            Objective function to optimize.
-        initial_position: Float[Array, " n_chain n_dim"]
-            Initial positions for the optimization.
-        data: dict
-            Data to pass to the objective function.
-
-    Returns:
-        rng_key: Key
-            Updated random key.
-        optimized_positions: Float[Array, " n_chain n_dim"]
-            Optimized positions.
-        final_log_prob: Float[Array, " n_chain"]
-            Final log-probabilities of the optimized positions.
-    """
     grad_fn = jax.jit(jax.grad(objective))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/strategy/optimization.py` around lines 95 - 122: the docstring for the `optimize` method is placed after the bounds validation and will not be recognized; move the triple-quoted docstring so it immediately follows the `def optimize(...)` signature (before any code such as the bounds validation referencing `initial_position` or `self.bounds`), preserving the current content and indentation, and remove the stray duplicate docstring block after the validation so `optimize.__doc__` is correct.

tests/unit/test_nf.py (1)
30-30: Remove dead key-split outputs in these tests.

Line 30 unpacks `rng_subkey` but never uses it. Lines 63-65 split a key, then ignore it and build the model from a separate key.

Proposed cleanup

```diff
-    rng_key, rng_subkey = jax.random.split(jax.random.key(0), 2)
+    rng_key, _ = jax.random.split(jax.random.key(0), 2)
@@
-    rng_key, rng_subkey = jax.random.split(jax.random.key(0), 2)
     model = MaskedCouplingRQSpline(
         n_features, n_layers, hidden_layes, n_bins, jax.random.key(10)
     )
```

Also applies to: 63-65
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/test_nf.py` at line 30, Remove unused split outputs by changing calls to jax.random.split that currently unpack into rng_key, rng_subkey (e.g., the statement using jax.random.split(jax.random.key(0), 2)) so only the used key is kept; specifically drop rng_subkey from the unpack and just assign the single used key (or call jax.random.key(0) directly) wherever rng_subkey is never referenced. Also fix the later occurrence around lines 63-65 where a key is split and the result ignored—stop creating the dead key and pass the actual key you use to build the model (reference jax.random.split and the variables rng_key/rng_subkey and the model construction call) so no unused key variables remain.src/flowMC/resource_strategy_bundle/RQSpline_GRW.py (1)
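Concretely, the cleanup this prompt describes is the difference between the following patterns (a sketch, not the test file itself):

```python
import jax

# Wasteful: rng_subkey is created and never consumed.
rng_key, rng_subkey = jax.random.split(jax.random.key(0), 2)

# Cleaner: discard the unused half explicitly...
rng_key, _ = jax.random.split(jax.random.key(0), 2)

# ...or, when only one fresh key is needed, skip the split entirely.
model_key = jax.random.key(10)
print(rng_key.shape, model_key.shape)  # () ()
```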
src/flowMC/resource_strategy_bundle/RQSpline_GRW.py (1)

231-238: Improve callback transparency by returning the strategy tuple explicitly.

The current code at lines 231-238 is safe — `Lambda.__call__` explicitly ignores callback return values and always returns `(rng_key, resources, initial_position)`. However, the callbacks return `None` (from `set_current_position`), which depends on this implementation detail. Making callbacks return the strategy tuple explicitly improves code clarity and reduces dependency on Lambda's internal behaviour.

Suggested improvement

```diff
- update_global_step = Lambda(
-     lambda rng_key, resources, initial_position, data: (
-         global_stepper.set_current_position(local_stepper.current_position)
-     )
- )
- update_local_step = Lambda(
-     lambda rng_key, resources, initial_position, data: (
-         local_stepper.set_current_position(global_stepper.current_position)
-     )
- )
+ def sync_global_step(rng_key, resources, initial_position, data):
+     global_stepper.set_current_position(local_stepper.current_position)
+     return rng_key, resources, initial_position
+
+ def sync_local_step(rng_key, resources, initial_position, data):
+     local_stepper.set_current_position(global_stepper.current_position)
+     return rng_key, resources, initial_position
+
+ update_global_step = Lambda(sync_global_step)
+ update_local_step = Lambda(sync_local_step)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW.py` around lines 231 - 238, The lambdas used for stepper callbacks (the Lambda instances assigned around update_local_step and the earlier global setter) currently call global_stepper.set_current_position(...) and local_stepper.set_current_position(...) which return None, relying on Lambda.__call__ to ignore callback return values; change each callback to perform the set_current_position call and then explicitly return the strategy tuple (rng_key, resources, initial_position) so the callbacks return the expected (rng_key, resources, initial_position) tuple instead of None (e.g., inside the Lambda bodies for the callbacks referencing global_stepper and local_stepper).
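To see why the explicit return matters, here is a minimal stand-in — the `Lambda` and `Stepper` classes below are hypothetical simplifications, not flowMC's real ones — showing callbacks that return the strategy tuple:

```python
class Lambda:
    """Toy wrapper: forwards the strategy call to a plain function."""
    def __init__(self, fn):
        self.fn = fn
    def __call__(self, rng_key, resources, initial_position, data):
        return self.fn(rng_key, resources, initial_position, data)

class Stepper:
    def __init__(self, position):
        self.current_position = position
    def set_current_position(self, position):
        self.current_position = position  # a setter: returns None

local_stepper = Stepper(position=[0.0])
global_stepper = Stepper(position=[1.0])

def sync_global_step(rng_key, resources, initial_position, data):
    global_stepper.set_current_position(local_stepper.current_position)
    return rng_key, resources, initial_position  # explicit strategy tuple

update_global_step = Lambda(sync_global_step)
result = update_global_step("key", {}, [0.0], {})
print(result, global_stepper.current_position)  # ('key', {}, [0.0]) [0.0]
```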
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py`:
- Line 250: Add a "# noqa: F722" comment to the jaxtyping forward-annotation
lines that use multi-word shape strings (e.g., the parameter signature
"initial_position: Float[Array, \"n_chains n_dim\"]") to silence Ruff F722;
update the exact signature lines at the shown locations (and the analogous lines
at 282 and 322 and the same patterns in RQSpline_GRW_PT.py, RQSpline_HMC.py,
RQSpline_MALA.py, and RQSpline_MALA_PT.py) by appending "# noqa: F722" to those
annotation lines so the multi-word shape literals are ignored by the
forward-annotation checker.
In `@src/flowMC/resource/model/flowmatching/base.py`:
- Around line 275-285: The train function's docstring is out of sync: remove the
obsolete `model` parameter mention and update parameter and return descriptions
to match the current signature (e.g., `rng`, `data`, `num_epochs`, `batch_size`,
`verbose`, and the optimizer/state objects actually passed such as `optim` or
`state`), and ensure the Returns section lists the actual returned items (e.g.,
updated `rng`, `state`/`optim`, and `loss_values` as appropriate); update
parameter names and types to match those used in the `train` function in this
file (search for `def train(`) and keep descriptions concise and accurate.
- Around line 171-172: The return type annotation for FlowMatchingModel.sample
is missing the batch dimension: update the annotated return shape to include
num_samples (e.g., change Float[Array, " n_dim"] to include the batch axis such
as Float[Array, " num_samples n_dim"]) so the signature for sample(self,
rng_key: Key, num_samples: int, dt: Float = 1e-1) -> ... correctly reflects a
batched array; locate the method named sample in class FlowMatchingModel in
base.py and adjust its return type annotation accordingly.
In `@src/flowMC/resource/model/nf_model/base.py`:
- Line 67: The forward-annotation strings contain leading spaces which cause
F722 errors; remove the stray spaces inside all Array dimension string
annotations (e.g., change Float[Array, " n_dim"] to Float[Array, "n_dim"] and
Float[Array, " n_batch n_dim"] to Float[Array, "n_batch n_dim"]) across this
module (methods using the Float[Array, "..."] annotation such as the method with
signature containing x: Float[Array, " n_dim"]) and also in
src/flowMC/resource_strategy_bundle/RQSpline_GRW.py so that no dimension string
contains leading spaces.
In `@src/flowMC/strategy/check_early_stop.py`:
- Around line 72-74: Docstring for the function/class that defines the
"patience" parameter (see signature with patience: int = 3) is
inconsistent—docstring says "Defaults to 1" while the actual default is 3;
update either the docstring or the function signature so they match. Locate the
"patience" parameter in check_early_stop (or the function/class in
src/flowMC/strategy/check_early_stop.py that declares patience: int = 3) and
change the docstring default text to "Defaults to 3" (or change the signature
default to 1 if intended) so the declared default and the docstring are
consistent. Ensure the docstring line describing patience is updated exactly
where the parameter is documented.
---
Nitpick comments:
In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW.py`:
- Around line 231-238: The lambdas used for stepper callbacks (the Lambda
instances assigned around update_local_step and the earlier global setter)
currently call global_stepper.set_current_position(...) and
local_stepper.set_current_position(...) which return None, relying on
Lambda.__call__ to ignore callback return values; change each callback to
perform the set_current_position call and then explicitly return the strategy
tuple (rng_key, resources, initial_position) so the callbacks return the
expected (rng_key, resources, initial_position) tuple instead of None (e.g.,
inside the Lambda bodies for the callbacks referencing global_stepper and
local_stepper).
In `@src/flowMC/resource/model/common.py`:
- Around line 54-55: The unused key parameter in Distribution.__call__ should be
explicitly acknowledged: inside the __call__ method (Distribution.__call__),
either add a short comment noting key is kept for API compatibility or add a
no-op statement like "del key" before returning self.log_prob(x) so
linters/readers know the unused parameter is intentional; update only the
__call__ method accordingly to avoid changing behavior.
In `@src/flowMC/resource/model/nf_model/rqSpline.py`:
- Around line 453-458: The forward method declares an unused parameter key;
either remove key from the forward signature if it's unnecessary, or preserve
API consistency by explicitly acknowledging it inside the method (e.g., add a
line like "_ = key" or a short comment "# key unused (kept for API
compatibility)" near the top of rqSpline.forward) so static analysis no longer
flags it; reference the forward method and the key parameter when making the
change.
In `@src/flowMC/strategy/optimization.py`:
- Around line 95-122: The docstring for the optimize method is placed after the
bounds validation and will not be recognized; move the triple-quoted docstring
so it immediately follows the def optimize(...) signature (before any code such
as the bounds validation referencing initial_position or self.bounds),
preserving the current content and indentation, and remove the stray duplicate
docstring block after the validation so optimize.__doc__ is correct.
In `@tests/unit/test_nf.py`:
- Line 30: Remove unused split outputs by changing calls to jax.random.split
that currently unpack into rng_key, rng_subkey (e.g., the statement using
jax.random.split(jax.random.key(0), 2)) so only the used key is kept;
specifically drop rng_subkey from the unpack and just assign the single used key
(or call jax.random.key(0) directly) wherever rng_subkey is never referenced.
Also fix the later occurrence around lines 63-65 where a key is split and the
result ignored—stop creating the dead key and pass the actual key you use to
build the model (reference jax.random.split and the variables rng_key/rng_subkey
and the model construction call) so no unused key variables remain.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (43)
- .coverage
- .github/renovate.json
- .pre-commit-config.yaml
- docs/quickstart.md
- pyproject.toml
- src/flowMC/Sampler.py
- src/flowMC/resource/kernel/Gaussian_random_walk.py
- src/flowMC/resource/kernel/HMC.py
- src/flowMC/resource/kernel/MALA.py
- src/flowMC/resource/kernel/NF_proposal.py
- src/flowMC/resource/kernel/base.py
- src/flowMC/resource/model/common.py
- src/flowMC/resource/model/flowmatching/base.py
- src/flowMC/resource/model/nf_model/base.py
- src/flowMC/resource/model/nf_model/realNVP.py
- src/flowMC/resource/model/nf_model/rqSpline.py
- src/flowMC/resource_strategy_bundle/RQSpline_GRW.py
- src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py
- src/flowMC/resource_strategy_bundle/RQSpline_HMC.py
- src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py
- src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
- src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
- src/flowMC/strategy/adapt_step_size.py
- src/flowMC/strategy/base.py
- src/flowMC/strategy/check_early_stop.py
- src/flowMC/strategy/lambda_function.py
- src/flowMC/strategy/optimization.py
- src/flowMC/strategy/parallel_tempering.py
- src/flowMC/strategy/sequential_monte_carlo.py
- src/flowMC/strategy/take_steps.py
- src/flowMC/strategy/train_model.py
- src/flowMC/strategy/update_state.py
- tests/integration/test_HMC.py
- tests/integration/test_MALA.py
- tests/integration/test_RWMCMC.py
- tests/integration/test_normalizingFlow.py
- tests/integration/test_quickstart.py
- tests/unit/test_bundle.py
- tests/unit/test_flowmatching.py
- tests/unit/test_kernels.py
- tests/unit/test_nf.py
- tests/unit/test_resources.py
- tests/unit/test_strategies.py
💤 Files with no reviewable changes (1)
- .github/renovate.json
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/flowMC/strategy/optimization.py (1)
95-122: ⚠️ Potential issue | 🟡 Minor

Docstring is misplaced after executable code.

The docstring for the `optimize` method is placed after the bounds validation logic (lines 95-101). Docstrings must immediately follow the function definition to be recognised by Python's documentation tools.

💡 Suggested fix

```diff
 def optimize(
     self,
     rng_key: Key,
     objective: Callable,
     initial_position: Float[Array, "n_chain n_dim"],
     data: dict,
 ):
+    """Optimization kernel. This can be used independently of the __call__ method.
+
+    Args:
+        rng_key: Key
+            Random key for the optimization.
+        objective: Callable
+            Objective function to optimize.
+        initial_position: Float[Array, "n_chain n_dim"]
+            Initial positions for the optimization.
+        data: dict
+            Data to pass to the objective function.
+
+    Returns:
+        rng_key: Key
+            Updated random key.
+        optimized_positions: Float[Array, "n_chain n_dim"]
+            Optimized positions.
+        final_log_prob: Float[Array, "n_chain"]
+            Final log-probabilities of the optimized positions.
+    """
     # Validate bounds shape against n_dim
     n_dim = initial_position.shape[-1]
     if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
         raise ValueError(
             f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
             "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
         )
-
-    """Optimization kernel. This can be used independently of the __call__ method.
-    ...
-    """
     grad_fn = jax.jit(jax.grad(objective))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/strategy/optimization.py` around lines 95 - 122, The optimize method's docstring is incorrectly placed after executable bounds validation; move the triple-quoted docstring so it immediately follows the def optimize(...) signature (before any code like the bounds check using self.bounds and initial_position), preserving the exact docstring content and indentation, so Python documentation tools will recognize it and the validation logic (n_dim, bounds shape check) remains unchanged after the docstring.
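The underlying Python rule is easy to demonstrate: a string literal that follows executable code is evaluated and discarded, never stored as `__doc__` (toy functions, not flowMC's code):

```python
def misplaced():
    n_dim = 1  # executable statement comes first
    """This string is NOT a docstring; Python discards it."""
    return n_dim

def well_placed():
    """This docstring immediately follows the signature, so it is kept."""
    return 1

print(misplaced.__doc__)    # None
print(well_placed.__doc__)  # the docstring text
```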
♻️ Duplicate comments (2)
src/flowMC/resource/model/flowmatching/base.py (2)
170-172: ⚠️ Potential issue | 🟡 Minor

Return type annotation is missing the batch dimension.

The `sample` method returns `num_samples` samples, but the return type annotation indicates a single vector `Float[Array, " n_dim"]` instead of `Float[Array, "num_samples n_dim"]`.

💡 Suggested fix

```diff
 def sample(
     self, rng_key: Key, num_samples: int, dt: Float = 1e-1
-) -> Float[Array, " n_dim"]:
+) -> Float[Array, "num_samples n_dim"]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/flowmatching/base.py` around lines 170 - 172, The sample method's return type annotation on class/method sample is missing the batch dimension; update the return annotation from Float[Array, " n_dim"] to include the num_samples dimension (e.g. Float[Array, "num_samples n_dim"]) so it correctly reflects that sample(self, rng_key: Key, num_samples: int, dt: Float = 1e-1) returns a batch of num_samples vectors; apply this change to the sample signature in flowmatching.base.py and ensure any whitespace/typing string formatting matches surrounding annotations.
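At runtime the batch axis is real, which is what the corrected annotation records. A toy sampler (hypothetical; not `FlowMatchingModel.sample`) shows the shape:

```python
import jax

def sample(rng_key, num_samples, n_dim=2):
    # Returns shape (num_samples, n_dim) — hence the annotation should be
    # Float[Array, "num_samples n_dim"] rather than Float[Array, " n_dim"].
    return jax.random.normal(rng_key, (num_samples, n_dim))

samples = sample(jax.random.key(0), num_samples=5)
print(samples.shape)  # (5, 2)
```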
274-285: ⚠️ Potential issue | 🟡 Minor

Docstring is out of sync with the function signature.

The docstring still documents a `model` argument that does not exist in the signature, and the Returns section is incomplete (missing `state` return value).

💡 Suggested docstring correction

```diff
 """Train a normalizing flow model.

 Args:
     rng (Key): JAX PRNGKey.
-    model (eqx.Module): NF model to train.
     data (Array): Training data.
+    optim (optax.GradientTransformation): Optimiser transformation.
+    state (optax.OptState): Optimiser state.
     num_epochs (int): Number of epochs to train for.
     batch_size (int): Batch size.
     verbose (bool): Whether to print progress.

 Returns:
     rng (Key): Updated JAX PRNGKey.
-    model (eqx.Model): Updated NF model.
+    model (Self): Updated NF model.
+    state (optax.OptState): Updated optimiser state.
     loss_values (Array): Loss values.
 """
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/flowmatching/base.py` around lines 274 - 285, The docstring currently documents a removed `model` parameter and omits the `state` return; update the docstring for the function containing the shown Args so the Args list matches the actual signature (include `rng`, `state`, `data`, `num_epochs`, `batch_size`, `verbose` as present) and remove any reference to `model`, and update the Returns section to list both `rng` and `state` with their correct types (and `loss_values` if returned) and short descriptions; also adjust the top summary line to reflect the real behavior of the function and ensure parameter names `rng` and `state` match exactly the symbols in the implementation.
🧹 Nitpick comments (3)
src/flowMC/resource/model/nf_model/realNVP.py (1)
177-182: Unused `key` parameter in `forward` method.

The `key` parameter is accepted but never used in the method body. RealNVP is a deterministic flow, so this is expected behaviour. If this parameter exists solely for interface compatibility with stochastic flow models, consider documenting this or prefixing with underscore (`_key`) to suppress the linter warning.

♻️ Optional: Suppress linter warning

```diff
 def forward(
     self,
     x: Float[Array, " n_dim"],
-    key: Optional[Key] = None,
+    _key: Optional[Key] = None,  # Unused; kept for interface compatibility
     condition: Optional[Float[Array, " n_condition"]] = None,
 ) -> tuple[Float[Array, " n_dim"], Float]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/nf_model/realNVP.py` around lines 177 - 182, The forward method in RealNVP (def forward(...)) accepts a key parameter but never uses it; to fix, either rename the parameter to _key to suppress linter warnings or add a short docstring/comment in the RealNVP.forward signature body stating that key is accepted for interface compatibility with stochastic flows and is intentionally unused; update references to the parameter name accordingly to avoid unused-variable errors.src/flowMC/resource/model/common.py (1)
54-55: Consider documenting or removing the unused `key` parameter.

The `key` parameter in `Distribution.__call__` is currently unused. If this is intentional for interface consistency with subclasses that may need randomness, consider adding a brief docstring explaining this. Otherwise, it could be removed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/common.py` around lines 54 - 55, The __call__ method on Distribution currently accepts an unused key parameter; either document its purpose for interface consistency or remove it: add a short docstring to Distribution.__call__ explaining that key is optional and unused by the base implementation but kept so subclasses that require RNG (e.g., sampling subclasses overriding __call__ or methods like log_prob) can accept it, or remove key from Distribution.__call__ signature and update any subclass overrides and callers to match (search for __call__ and log_prob in the class to locate affected code).src/flowMC/strategy/optimization.py (1)
42-54: Consider using `None` as default and initialising `bounds` inside the function.

Using `jnp.array(...)` as a default argument value triggers B008. While JAX arrays are immutable and this is safe in practice, the idiomatic pattern is to use `None` as the default and initialise inside the function body.

💡 Suggested fix

```diff
 def __init__(
     self,
     logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
     n_steps: int = 100,
     learning_rate: float = 1e-2,
     noise_level: float = 10,
-    bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
+    bounds: Float[Array, "n_dim 2"] | None = None,
 ):
     self.logpdf = logpdf
     self.n_steps = n_steps
     self.learning_rate = learning_rate
     self.noise_level = noise_level
-    self.bounds = bounds
+    self.bounds = bounds if bounds is not None else jnp.array([[-jnp.inf, jnp.inf]])
```

Note: The class attribute default and docstring would also need updating.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/strategy/optimization.py` around lines 42 - 54, The __init__ method currently uses a JAX array literal as the default for the bounds parameter which triggers B008; change the signature to use bounds: Optional[...]=None and inside __init__ set self.bounds = jnp.array([[-jnp.inf, jnp.inf]]) if bounds is None else bounds (ensuring the type matches expected Float[Array, "n_dim 2"]), and update the constructor docstring/default class attribute comment accordingly; refer to the __init__ parameter name bounds and the instance attribute self.bounds when making the change.
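The `None`-default idiom the prompt describes, sketched on a stripped-down class (hypothetical; only the `bounds` handling is shown, not flowMC's full optimizer):

```python
import jax.numpy as jnp

class Optimization:
    """Sketch of the suggested bounds handling under the B008 fix."""
    def __init__(self, bounds=None):
        # Avoid evaluating jnp.array(...) in the signature (Ruff B008):
        # use None as the default and initialise inside the body.
        if bounds is None:
            bounds = jnp.array([[-jnp.inf, jnp.inf]])
        self.bounds = bounds

default = Optimization()
custom = Optimization(bounds=jnp.array([[0.0, 1.0]]))
print(default.bounds.shape, custom.bounds.shape)  # (1, 2) (1, 2)
```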
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/flowMC/resource/model/nf_model/base.py`:
- Around line 166-179: Update the docstring for the train function in
nf_model.base to match the actual signature: remove the nonexistent `model`
parameter from the Args section, ensure the listed arguments (e.g., rng, data,
num_epochs, batch_size, verbose) exactly match the function signature, and add
the missing `state` to the Returns section (include its type, e.g.,
optimizer/state object or eqx state) alongside rng, model, and loss_values;
reference the function by name (train) and the returned `state` symbol when
making these edits.
---
Outside diff comments:
In `@src/flowMC/strategy/optimization.py`:
- Around line 95-122: The optimize method's docstring is incorrectly placed
after executable bounds validation; move the triple-quoted docstring so it
immediately follows the def optimize(...) signature (before any code like the
bounds check using self.bounds and initial_position), preserving the exact
docstring content and indentation, so Python documentation tools will recognize
it and the validation logic (n_dim, bounds shape check) remains unchanged after
the docstring.
---
Duplicate comments:
In `@src/flowMC/resource/model/flowmatching/base.py`:
- Around line 170-172: The sample method's return type annotation on
class/method sample is missing the batch dimension; update the return annotation
from Float[Array, " n_dim"] to include the num_samples dimension (e.g.
Float[Array, "num_samples n_dim"]) so it correctly reflects that sample(self,
rng_key: Key, num_samples: int, dt: Float = 1e-1) returns a batch of num_samples
vectors; apply this change to the sample signature in flowmatching.base.py and
ensure any whitespace/typing string formatting matches surrounding annotations.
- Around line 274-285: The docstring currently documents a removed `model`
parameter and omits the `state` return; update the docstring for the function
containing the shown Args so the Args list matches the actual signature (include
`rng`, `state`, `data`, `num_epochs`, `batch_size`, `verbose` as present) and
remove any reference to `model`, and update the Returns section to list both
`rng` and `state` with their correct types (and `loss_values` if returned) and
short descriptions; also adjust the top summary line to reflect the real
behavior of the function and ensure parameter names `rng` and `state` match
exactly the symbols in the implementation.
---
Nitpick comments:
In `@src/flowMC/resource/model/common.py`:
- Around line 54-55: The __call__ method on Distribution currently accepts an
unused key parameter; either document its purpose for interface consistency or
remove it: add a short docstring to Distribution.__call__ explaining that key is
optional and unused by the base implementation but kept so subclasses that
require RNG (e.g., sampling subclasses overriding __call__ or methods like
log_prob) can accept it, or remove key from Distribution.__call__ signature and
update any subclass overrides and callers to match (search for __call__ and
log_prob in the class to locate affected code).
In `@src/flowMC/resource/model/nf_model/realNVP.py`:
- Around line 177-182: The forward method in RealNVP (def forward(...)) accepts
a key parameter but never uses it; to fix, either rename the parameter to _key
to suppress linter warnings or add a short docstring/comment in the
RealNVP.forward signature body stating that key is accepted for interface
compatibility with stochastic flows and is intentionally unused; update
references to the parameter name accordingly to avoid unused-variable errors.
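The rename option can be sketched as follows (the `AffineCouplingSketch` class and its pure-shift transform are hypothetical; only the `_key` convention mirrors the finding):

```python
class AffineCouplingSketch:
    """Hypothetical deterministic layer standing in for RealNVP."""

    def __init__(self, shift: float):
        self.shift = shift

    def forward(self, x: float, _key=None):
        # `_key` is accepted only for interface compatibility with
        # stochastic flows and is intentionally unused; the leading
        # underscore silences unused-argument linters.
        log_det = 0.0  # a pure shift has zero log-determinant
        return x + self.shift, log_det


y, log_det = AffineCouplingSketch(shift=1.0).forward(2.0)
print(y, log_det)  # 3.0 0.0
```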
In `@src/flowMC/strategy/optimization.py`:
- Around line 42-54: The __init__ method currently uses a JAX array literal as
the default for the bounds parameter which triggers B008; change the signature
to use bounds: Optional[...]=None and inside __init__ set self.bounds =
jnp.array([[-jnp.inf, jnp.inf]]) if bounds is None else bounds (ensuring the
type matches expected Float[Array, "n_dim 2"]), and update the constructor
docstring/default class attribute comment accordingly; refer to the __init__
parameter name bounds and the instance attribute self.bounds when making the
change.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- .pre-commit-config.yaml
- src/flowMC/resource/buffers.py
- src/flowMC/resource/kernel/HMC.py
- src/flowMC/resource/kernel/MALA.py
- src/flowMC/resource/kernel/base.py
- src/flowMC/resource/model/common.py
- src/flowMC/resource/model/flowmatching/base.py
- src/flowMC/resource/model/nf_model/base.py
- src/flowMC/resource/model/nf_model/realNVP.py
- src/flowMC/strategy/optimization.py
- src/flowMC/strategy/parallel_tempering.py
✅ Files skipped from review due to trivial changes (1)
- src/flowMC/resource/buffers.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/flowMC/resource/kernel/MALA.py
- .pre-commit-config.yaml
Actionable comments posted: 2
♻️ Duplicate comments (1)
src/flowMC/resource/model/flowmatching/base.py (1)
171-172: ⚠️ Potential issue | 🟡 Minor
`FlowMatchingModel.sample` return annotation still misses the batch axis. At line 172, the annotated return type is a single vector, but this method returns `num_samples` rows.
💡 Annotation fix: change `) -> Float[Array, " n_dim"]:` to `) -> Float[Array, "num_samples n_dim"]:`
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/flowMC/resource/model/flowmatching/base.py` around lines 171 - 172, The return type annotation of FlowMatchingModel.sample is incorrect: it declares a single vector but actually returns num_samples rows; update the signature for sample (method name FlowMatchingModel.sample with parameters rng_key, num_samples, dt) to include the batch axis — e.g. change the return annotation to include the leading num_samples dimension (such as Float[Array, "num_samples n_dim"] or equivalent in your typing convention) so the annotated shape matches the actual returned array.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/flowMC/resource/model/flowmatching/base.py`:
- Line 136: Several jaxtyping forward annotations like the attribute _data_cov
use shape strings with spaces (e.g., Float[Array, "n_dim n_dim"]) which trigger
Ruff F722; fix by appending "# noqa: F722" to each affected annotation line
using the jaxtyping Float[Array, "..."] pattern in this file (for example the
_data_cov attribute and all other class/attribute annotations that use
Float[Array, "..."]), ensuring every occurrence listed in the review gets the
noqa comment so CI passes; alternatively, if you prefer project-wide handling,
add an exclusion for F722 for this module in the Ruff config instead of per-line
noqa.
In `@src/flowMC/strategy/check_early_stop.py`:
- Around line 117-131: The asserts use walrus assignments which are removed with
-O; separate assignments from checks in the function (look for state_name and
acceptance_buffer_key usage inside check_early_stop) by first assigning state =
resources[self.state_name], buffer_name =
state.data.get(self.acceptance_buffer_key), and acceptance_buffer =
resources[buffer_name], then perform explicit validation (either plain assert
without walrus or explicit if checks raising TypeError/ValueError) that state is
a State, buffer_name is a str, and acceptance_buffer is a Buffer so the
variables exist regardless of optimization flags.
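A self-contained sketch of the assign-then-validate pattern (the toy `State`/`Buffer` classes, the default resource names, and the raised messages are hypothetical; the lookup order follows the finding):

```python
from dataclasses import dataclass, field


@dataclass
class State:
    """Toy stand-in for flowMC's State resource."""
    data: dict = field(default_factory=dict)


@dataclass
class Buffer:
    """Toy stand-in for the acceptance Buffer resource."""
    values: list = field(default_factory=list)


def check_early_stop(resources, state_name="train_state",
                     acceptance_buffer_key="acceptance_buffer"):
    # Assign first, then validate explicitly: unlike
    # `assert isinstance(state := resources[state_name], State)`,
    # these names still exist when asserts are stripped under `python -O`.
    state = resources[state_name]
    if not isinstance(state, State):
        raise TypeError(f"{state_name} is not a State")
    buffer_name = state.data.get(acceptance_buffer_key)
    if not isinstance(buffer_name, str):
        raise ValueError(f"{acceptance_buffer_key} missing from state data")
    acceptance_buffer = resources[buffer_name]
    if not isinstance(acceptance_buffer, Buffer):
        raise TypeError(f"{buffer_name} is not a Buffer")
    return acceptance_buffer


resources = {
    "train_state": State(data={"acceptance_buffer": "acc"}),
    "acc": Buffer(values=[0.4, 0.5]),
}
print(check_early_stop(resources).values)  # [0.4, 0.5]
```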
---
Duplicate comments:
In `@src/flowMC/resource/model/flowmatching/base.py`:
- Around line 171-172: The return type annotation of FlowMatchingModel.sample is
incorrect: it declares a single vector but actually returns num_samples rows;
update the signature for sample (method name FlowMatchingModel.sample with
parameters rng_key, num_samples, dt) to include the batch axis — e.g. change the
return annotation to include the leading num_samples dimension (such as
Float[Array, "num_samples n_dim"] or equivalent in your typing convention) so
the annotated shape matches the actual returned array.
⛔ Files ignored due to path filters (1)
`uv.lock` is excluded by `!**/*.lock`
📒 Files selected for processing (3)
- src/flowMC/resource/model/flowmatching/base.py
- src/flowMC/resource/model/nf_model/base.py
- src/flowMC/strategy/check_early_stop.py
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py`:
- Line 57: Remove the leading-space shape tokens and fix multi-token jaxtyping
shape strings in this file: e.g., change grw_step_size: Float | Float[Array, "
n_dim"] to use "n_dim" (no leading space) and for annotations currently using
"n_chains n_dim" (several array/type hints throughout RQSpline_GRW_PT.py) either
convert them to a comma-separated tuple-style shape string like "n_chains,
n_dim" or suppress the Ruff F722 by adding a trailing # noqa: F722 on those
specific annotations; update all occurrences (including the symbols where you
saw errors) so no shape string contains a leading space and multi-token shapes
are expressed in the accepted form or explicitly noqa'd.
In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW.py`:
- Line 50: The jaxtyping annotation for grw_step_size uses a malformed shape
string with a leading space (" n_dim")—remove the leading space so it reads
"n_dim" in the grw_step_size annotation; additionally locate all other
annotations that use multi-token shape strings like "n_chains n_dim" and
suppress Ruff F722 by appending "# noqa: F722" to those annotation lines (there
are four such occurrences referenced in the review) so the parser accepts them.
Ensure you update the annotation for grw_step_size (the grw_step_size
variable/type hint) and add "# noqa: F722" to each multi-token shape string
annotation occurrence to resolve all six F722 violations.
In `@src/flowMC/resource_strategy_bundle/RQSpline_HMC.py`:
- Line 51: Two forward-annotation issues: remove the leading space in any
jaxtyping shape strings like " n_dim" (e.g., the annotation on condition_matrix
and the prior_mean annotation) so they read "n_dim"; and suppress Python's F722
parser error on the valid jaxtyping shapes that contain an embedded space by
appending "# noqa: F722" to the four annotations that use the shape string
"n_chains n_dim" (the four locations flagged by the linter). Locate the
annotations by searching for the exact shape substrings " n_dim" and "n_chains
n_dim", remove the leading space in the former and add the trailing "# noqa:
F722" comment to the latter so the jaxtyping syntax is preserved while silencing
F722.
In `@src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py`:
- Line 57: Remove the leading space inside jaxtyping shape strings for
annotations: change " n_dim" to "n_dim" in the annotations for logpdf,
mala_step_size, and logprior (look for parameters named logpdf, mala_step_size,
and logprior in RQSpline_MALA_PT.py). For the multi-word shape annotations that
trigger Ruff F722, append a local suppression comment "# noqa: F722" to the
affected parameter annotations that use Float[Array, "n_chains n_dim"] (e.g.,
the parameters named initial_position and any other parameters typed as
Float[Array, "n_chains n_dim"] in the function signatures around the occurrences
reported); do not alter the jaxtyping strings themselves—only remove the stray
leading space in single-word shapes and add "# noqa: F722" to the multi-word
shape annotation lines.
In `@src/flowMC/resource_strategy_bundle/RQSpline_MALA.py`:
- Line 49: Fix the broken jaxtyping shape strings in the RQSpline_MALA
annotations: remove the leading space in the single-dimension shape string ("
n_dim" → "n_dim") for the parameter annotated as mala_step_size (and any other
annotations using " n_dim"), and change the multi-dimension shape strings from
space-separated to comma-separated ("n_chains n_dim" → "n_chains, n_dim") for
the four annotations flagged (the ones using "n_chains n_dim"); remove any "#
noqa: F722" suppressions after correcting the strings so the annotations follow
valid jaxtyping syntax.
📒 Files selected for processing (6)
- src/flowMC/resource_strategy_bundle/RQSpline_GRW.py
- src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py
- src/flowMC/resource_strategy_bundle/RQSpline_HMC.py
- src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py
- src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
- src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
Summary by CodeRabbit
- New Features
- Refactor
- Documentation
- Chores
- Tests