Merge recent updates from dev branch #58

Merged
thomasckng merged 35 commits into main from flowMC-dev
Mar 2, 2026

@thomasckng (Member) commented Feb 27, 2026

Summary by CodeRabbit

  • New Features

    • Added a CheckEarlyStop strategy and optional early‑stopping and adaptive step‑size controls for samplers.
  • Refactor

    • Migrated to the newer JAX RNG API (jax.random.key) and updated jaxtyping Key usage and related type hints across the library.
  • Documentation

    • Updated quickstart to use the new RNG API.
  • Chores

    • Updated project metadata, bumped pre‑commit tools, and removed Renovate configuration.
  • Tests

    • Added/updated tests for early‑stop behaviour and RNG initialisation.

- Updated the usage of PRNGKeyArray to Key in various modules to align with the latest jaxtyping standards.
- Modified the initialization of random keys in tests to use the new key generation method.
- Ensured consistency in random key handling across resource strategies, strategies, and tests.
feat: implement early stopping strategy
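
For readers unfamiliar with the RNG change, the following is a minimal sketch of the difference between the legacy `jax.random.PRNGKey` and the newer `jax.random.key`. It uses only public JAX calls and is not taken from the flowMC sources; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

# Legacy-style key: a raw uint32 array (no dedicated key dtype).
legacy_key = jax.random.PRNGKey(0)

# New-style typed key: a scalar array whose dtype marks it as a PRNG key.
typed_key = jax.random.key(0)

# Typed keys are split and consumed exactly like legacy keys.
typed_key, subkey = jax.random.split(typed_key)
sample = jax.random.normal(subkey, (3,))

# The documented way to tell the two apart is a dtype check.
is_typed = jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key)
is_legacy_typed = jnp.issubdtype(legacy_key.dtype, jax.dtypes.prng_key)
```

As far as the jaxtyping documentation describes it, the `Key` annotation targets typed keys of this kind, while `PRNGKeyArray` also accepts the legacy `uint32` form.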
thomasckng added this to the FlowMC 0.5.0 milestone on Feb 27, 2026
thomasckng self-assigned this on Feb 27, 2026
thomasckng added the enhancement (New feature or request) label on Feb 27, 2026
@coderabbitai (bot) commented Feb 27, 2026

Warning

Rate limit exceeded

@thomasckng has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 19 minutes and 8 seconds before requesting another review.

📥 Commits

Reviewing files that changed from the base of the PR and between 31bb6f4 and e2f2112.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (1)
  • src/flowMC/resource/model/flowmatching/base.py
📝 Walkthrough

Migrates JAX RNG typing from PRNGKeyArray to Key across the codebase, adds a new CheckEarlyStop strategy and wires early-stopping into bundles and the Sampler, standardises lambda/stepper signatures, removes Renovate config, updates pre-commit hooks and project metadata, and updates docs/tests to the new JAX key API.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Infrastructure & Tooling: `.github/renovate.json`, `.pre-commit-config.yaml`, `pyproject.toml` | Removed Renovate config; bumped pre-commit hooks; updated author name and added a maintainer. |
| Type Migration (Core RNG Types): `src/flowMC/.../kernel/*`, `src/flowMC/.../model/*`, `src/flowMC/strategy/*`, `src/flowMC/resource_strategy_bundle/*`, `src/flowMC/Sampler.py` | Replaced PRNGKeyArray with Key in imports, type hints and public signatures across kernels, models, strategies, bundles and Sampler; fixed minor annotation spacing. |
| New Strategy: `src/flowMC/strategy/check_early_stop.py` | Added CheckEarlyStop strategy that monitors acceptance buffers, computes mean/CoV, tracks patience and sets state.data["early_stopped"] when criteria are met. |
| Early-stopping Integration: `src/flowMC/resource_strategy_bundle/*` | Added constructor flags (adapt_step_size, early_stopping, tolerance/patience/min_acceptance), added "early_stopped" to sampler_state, registered the check_early_stop strategy and conditionally included it in training phases. |
| Sampler Control-flow Update: `src/flowMC/Sampler.py` | Sampler inspects State resources for early_stopped and sets skip_to_production to bypass remaining training strategies until reset_steppers clears it; improved invalid-strategy error messages. |
| Lambda / Stepper Signatures: `src/flowMC/resource_strategy_bundle/*_*.py` | Standardised lambda/stepper signatures to (rng_key, resources, initial_position, data) returning tuples; updated types to Key; set fixed n_loops_skip defaults for adapters/checks. |
| Strategy & Optimization Type Updates: `src/flowMC/strategy/{adapt_step_size,optimization,parallel_tempering,sequential_monte_carlo,take_steps,lambda_function,train_model,update_state}.py` | Updated strategy/optimization public signatures and returns to use Key; formatting cleanups in annotation strings. |
| Model & Kernel Type Updates: `src/flowMC/resource/model/*`, `src/flowMC/resource/kernel/*` | Updated public constructor/method signatures and docstrings to use Key instead of PRNGKeyArray. |
| Docs & Tests (JAX RNG API): `docs/quickstart.md`, `tests/{integration,unit}/*` | Replaced jax.random.PRNGKey(...) with jax.random.key(...) across docs and tests. |
| Resource Buffers minor typing: `src/flowMC/resource/buffers.py` | Trimmed spacing inside string-literal shape annotations for Buffer.data and related typing. |

Sequence Diagram(s)

sequenceDiagram
    participant Sampler
    participant Strategy as Training\nStrategy
    participant State
    participant Resources

    Sampler->>Strategy: Call strategy(rng_key, resources, ...)
    Strategy-->>Sampler: (rng_key, resources, position)
    Sampler->>State: Read state.data["early_stopped"]
    alt early_stopped == True
        Sampler->>Sampler: Set skip_to_production = True
        Sampler->>Strategy: Skip remaining training strategies until reset_steppers
    else
        Sampler->>Strategy: Continue calling next training strategy
    end
    alt on reset_steppers
        Sampler->>State: Clear state.data["early_stopped"]
        Sampler->>Sampler: Set skip_to_production = False (enter production)
    end
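
The control flow in the diagram above can be sketched in plain Python. This is a hypothetical driver loop written only to illustrate the described behaviour; `run_strategy_order`, the strategy names in `order`, and the dictionary-based `state` are assumptions and do not reproduce flowMC's `Sampler`.

```python
def run_strategy_order(strategy_order, strategies, state):
    """Run strategies in order, skipping training once early_stopped is set."""
    executed = []
    skip_to_production = False
    for name in strategy_order:
        if name == "reset_steppers":
            # Reaching the reset marks the start of production: clear the flags.
            state["early_stopped"] = False
            skip_to_production = False
        if skip_to_production:
            continue  # bypass the remaining training strategies
        strategies[name](state)
        executed.append(name)
        if state.get("early_stopped", False):
            skip_to_production = True
    return executed


def noop(state):
    """Placeholder for a strategy that leaves the state untouched."""


def trip_early_stop(state):
    """Placeholder check that always requests an early stop."""
    state["early_stopped"] = True


strategies = {
    "train": noop,
    "check_early_stop": trip_early_stop,
    "reset_steppers": noop,
    "production": noop,
}
order = [
    "train", "check_early_stop",   # first training loop
    "train", "check_early_stop",   # second training loop (skipped)
    "reset_steppers", "production",
]
state = {"early_stopped": False}
executed = run_strategy_order(order, strategies, state)
```

In this sketch the second training loop never runs, and the flag is cleared again before production starts.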
sequenceDiagram
    participant Bundle
    participant Check as CheckEarlyStop\nStrategy
    participant State
    participant Buffer as Acceptance\nBuffer

    Bundle->>Check: Instantiate(params: tolerance, patience, ...)
    loop each training loop
        Bundle->>Check: __call__(rng_key, resources, initial_position, data)
        Check->>State: Read state resource (state_name)
        State->>Buffer: Provide acceptance buffer data
        Check->>Check: Compute mean & CoV, compare to previous
        alt Stability & patience reached OR min_acceptance reached
            Check->>State: state.data["early_stopped"] = True
            Check-->>Bundle: Return updated resources with early_stopped
        else
            Check-->>Bundle: Return without change
        end
    end
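
A rough sketch of the kind of check the second diagram describes is shown below. The stability rule used here (absolute change in mean and coefficient of variation below `tolerance`) is an assumption made for illustration, as are the class name `EarlyStopSketch` and the default values; the actual criteria live in `src/flowMC/strategy/check_early_stop.py` and may differ.

```python
from dataclasses import dataclass
from statistics import fmean, pstdev
from typing import List, Optional, Tuple


@dataclass
class EarlyStopSketch:
    """Hypothetical stand-in for an early-stop check; not flowMC's CheckEarlyStop."""

    tolerance: float = 0.05      # assumed: max change in mean/CoV counted as stable
    patience: int = 3            # consecutive stable checks required before stopping
    min_acceptance: float = 0.5  # assumed: mean acceptance that is good enough alone
    _previous: Optional[Tuple[float, float]] = None
    _stable_checks: int = 0

    def update(self, acceptances: List[float], state: dict) -> dict:
        mean = fmean(acceptances)
        # Coefficient of variation: spread of the acceptance rates relative to the mean.
        cov = pstdev(acceptances) / mean if mean > 0 else float("inf")

        stable = (
            self._previous is not None
            and abs(mean - self._previous[0]) < self.tolerance
            and abs(cov - self._previous[1]) < self.tolerance
        )
        # A single unstable check resets the patience counter.
        self._stable_checks = self._stable_checks + 1 if stable else 0
        self._previous = (mean, cov)

        if self._stable_checks >= self.patience or mean >= self.min_acceptance:
            state["early_stopped"] = True
        return state


# Stability path: identical statistics on consecutive checks.
patient = EarlyStopSketch(tolerance=0.05, patience=2, min_acceptance=0.9)
state_a = {"early_stopped": False}
history = [patient.update([0.2, 0.4], state_a)["early_stopped"] for _ in range(3)]

# Threshold path: the mean acceptance is already high enough.
eager = EarlyStopSketch(min_acceptance=0.5)
state_b = eager.update([0.6, 0.6], {"early_stopped": False})
```

The first check can never count as stable because there is no previous value to compare against, so stopping on stability takes at least `patience + 1` checks in this sketch.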

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

dependencies

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 33.75%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Title check | ❓ Inconclusive | The title 'Merge recent updates from dev branch' is vague and generic, using non-descriptive language that does not meaningfully convey the actual changes made in the pull request. | Consider revising the title to be more specific about the primary changes, such as 'Migrate JAX RNG types from PRNGKeyArray to Key' or 'Add early-stopping mechanism and update RNG type system', which would better reflect the substantial type system migration and new feature additions. |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |

@thomasckng (Member, Author) commented:

@coderabbitai review

@coderabbitai (bot) commented Feb 27, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai (bot) left a comment

Actionable comments posted: 5

🧹 Nitpick comments (5)
src/flowMC/resource/model/common.py (1)

54-55: Unused key parameter in Distribution.__call__.

The key parameter is declared but unused — the method simply delegates to self.log_prob(x). This appears intentional for API consistency, allowing subclasses to optionally use stochastic evaluation.

Consider adding a brief comment or using del key to explicitly acknowledge this is intentional, similar to the pattern suggested for rqSpline.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/common.py` around lines 54 - 55, The unused key
parameter in Distribution.__call__ should be explicitly acknowledged: inside the
__call__ method (Distribution.__call__), either add a short comment noting key
is kept for API compatibility or add a no-op statement like "del key" before
returning self.log_prob(x) so linters/readers know the unused parameter is
intentional; update only the __call__ method accordingly to avoid changing
behavior.
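
A minimal illustration of the `del key` pattern mentioned above follows. `DistributionSketch` and its toy `log_prob` are hypothetical and stand in for the real `Distribution` class, whose body is not shown in this review.

```python
from typing import Any, Optional


class DistributionSketch:
    """Toy distribution with an unnormalised standard-normal log-density."""

    def log_prob(self, x: float) -> float:
        return -0.5 * x * x

    def __call__(self, x: float, key: Optional[Any] = None) -> float:
        del key  # Unused; accepted only so every model shares one call signature.
        return self.log_prob(x)


value = DistributionSketch()(2.0, key="any-key")
```

Deleting the name keeps the public signature unchanged while making it explicit to readers and linters that the argument is ignored on purpose.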
src/flowMC/resource/model/nf_model/rqSpline.py (1)

453-458: Unused key parameter in forward method.

The static analysis correctly identifies that the key parameter is declared but never used within the forward method. This parameter appears to be included for API consistency with other methods or potential future use with stochastic forward passes.

If the parameter is intentional for interface consistency, consider adding a brief comment or using _ = key to explicitly acknowledge it's unused. Otherwise, if it's not needed, consider removing it.

🔧 Option to explicitly acknowledge unused parameter
     def forward(
         self,
         x: Float[Array, " n_dim"],
         key: Optional[Key] = None,
         condition: Optional[Float[Array, " n_condition"]] = None,
     ) -> tuple[Float[Array, " n_dim"], Float]:
+        del key  # Unused, included for API consistency
         log_det = 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/nf_model/rqSpline.py` around lines 453 - 458, The
forward method declares an unused parameter key; either remove key from the
forward signature if it's unnecessary, or preserve API consistency by explicitly
acknowledging it inside the method (e.g., add a line like "_ = key" or a short
comment "# key unused (kept for API compatibility)" near the top of
rqSpline.forward) so static analysis no longer flags it; reference the forward
method and the key parameter when making the change.
src/flowMC/strategy/optimization.py (1)

95-122: Misplaced docstring will not be recognised by Python.

The docstring for the optimize method (lines 103–122) is placed after the bounds validation code (lines 95–101). In Python, a docstring must be the first statement in the function body; a string literal placed after other statements is an ordinary expression, so help() and optimize.__doc__ will not see it. This appears to be a pre-existing issue but is worth addressing.

♻️ Proposed fix to relocate the docstring
     def optimize(
         self,
         rng_key: Key,
         objective: Callable,
         initial_position: Float[Array, " n_chain n_dim"],
         data: dict,
     ):
+        """Optimization kernel. This can be used independently of the __call__ method.
+
+        Args:
+            rng_key: Key
+                Random key for the optimization.
+            objective: Callable
+                Objective function to optimize.
+            initial_position: Float[Array, " n_chain n_dim"]
+                Initial positions for the optimization.
+            data: dict
+                Data to pass to the objective function.
+
+        Returns:
+            rng_key: Key
+                Updated random key.
+            optimized_positions: Float[Array, " n_chain n_dim"]
+                Optimized positions.
+            final_log_prob: Float[Array, " n_chain"]
+                Final log-probabilities of the optimized positions.
+        """
         # Validate bounds shape against n_dim
         n_dim = initial_position.shape[-1]
         if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
             raise ValueError(
                 f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
                 "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
             )

-        """Optimization kernel. This can be used independently of the __call__ method.
-
-        Args:
-            rng_key: Key
-                Random key for the optimization.
-            objective: Callable
-                Objective function to optimize.
-            initial_position: Float[Array, " n_chain n_dim"]
-                Initial positions for the optimization.
-            data: dict
-                Data to pass to the objective function.
-
-        Returns:
-            rng_key: Key
-                Updated random key.
-            optimized_positions: Float[Array, " n_chain n_dim"]
-                Optimized positions.
-            final_log_prob: Float[Array, " n_chain"]
-                Final log-probabilities of the optimized positions.
-        """
         grad_fn = jax.jit(jax.grad(objective))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/strategy/optimization.py` around lines 95 - 122, The docstring for
the optimize method is placed after the bounds validation and will not be
recognized; move the triple-quoted docstring so it immediately follows the def
optimize(...) signature (before any code such as the bounds validation
referencing initial_position or self.bounds), preserving the current content and
indentation, and remove the stray duplicate docstring block after the validation
so optimize.__doc__ is correct.
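
The effect of a misplaced docstring is easy to reproduce in isolation. The two functions below are illustrative only and are unrelated to the `optimize` implementation.

```python
def misplaced(x):
    y = x + 1
    """This string is an ordinary expression statement, not a docstring."""
    return y


def correct(x):
    """Add one to x."""
    return x + 1


misplaced_doc = misplaced.__doc__  # None: the string is not the first statement
correct_doc = correct.__doc__      # "Add one to x."
```

Both functions behave identically at run time; only the second exposes its documentation to `help()` and documentation generators.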
tests/unit/test_nf.py (1)

30-30: Remove dead key-split outputs in these tests.

Line 30 unpacks rng_subkey but never uses it.
Lines 63-65 split a key, then ignore it and build the model from a separate key.

Proposed cleanup
-    rng_key, rng_subkey = jax.random.split(jax.random.key(0), 2)
+    rng_key, _ = jax.random.split(jax.random.key(0), 2)
@@
-    rng_key, rng_subkey = jax.random.split(jax.random.key(0), 2)
     model = MaskedCouplingRQSpline(
         n_features, n_layers, hidden_layes, n_bins, jax.random.key(10)
     )

Also applies to: 63-65

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/test_nf.py` at line 30, Remove unused split outputs by changing
calls to jax.random.split that currently unpack into rng_key, rng_subkey (e.g.,
the statement using jax.random.split(jax.random.key(0), 2)) so only the used key
is kept; specifically drop rng_subkey from the unpack and just assign the single
used key (or call jax.random.key(0) directly) wherever rng_subkey is never
referenced. Also fix the later occurrence around lines 63-65 where a key is
split and the result ignored—stop creating the dead key and pass the actual key
you use to build the model (reference jax.random.split and the variables
rng_key/rng_subkey and the model construction call) so no unused key variables
remain.
src/flowMC/resource_strategy_bundle/RQSpline_GRW.py (1)

231-238: Improve callback transparency by returning the strategy tuple explicitly.

The current code at lines 231–238 is safe: Lambda.__call__ explicitly ignores callback return values and always returns (rng_key, resources, initial_position). However, the callbacks return None (from set_current_position), so their correctness depends on this implementation detail. Making the callbacks return the strategy tuple explicitly improves clarity and reduces the dependency on Lambda's internal behaviour.

Suggested improvement
-        update_global_step = Lambda(
-            lambda rng_key, resources, initial_position, data: (
-                global_stepper.set_current_position(local_stepper.current_position)
-            )
-        )
-        update_local_step = Lambda(
-            lambda rng_key, resources, initial_position, data: (
-                local_stepper.set_current_position(global_stepper.current_position)
-            )
-        )
+        def sync_global_step(rng_key, resources, initial_position, data):
+            global_stepper.set_current_position(local_stepper.current_position)
+            return rng_key, resources, initial_position
+
+        def sync_local_step(rng_key, resources, initial_position, data):
+            local_stepper.set_current_position(global_stepper.current_position)
+            return rng_key, resources, initial_position
+
+        update_global_step = Lambda(sync_global_step)
+        update_local_step = Lambda(sync_local_step)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW.py` around lines 231 - 238,
The lambdas used for stepper callbacks (the Lambda instances assigned around
update_local_step and the earlier global setter) currently call
global_stepper.set_current_position(...) and
local_stepper.set_current_position(...) which return None, relying on
Lambda.__call__ to ignore callback return values; change each callback to
perform the set_current_position call and then explicitly return the strategy
tuple (rng_key, resources, initial_position) so the callbacks return the
expected (rng_key, resources, initial_position) tuple instead of None (e.g.,
inside the Lambda bodies for the callbacks referencing global_stepper and
local_stepper).
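
To make the point concrete, here is a self-contained sketch of a wrapper that discards the callback's return value, as the review describes for `Lambda.__call__`. `LambdaSketch` and `StepperSketch` are hypothetical stand-ins written for this example, not the flowMC classes.

```python
class StepperSketch:
    """Minimal object holding a position, mimicking a stepper resource."""

    def __init__(self, position):
        self.current_position = position

    def set_current_position(self, position):
        self.current_position = position  # implicitly returns None


class LambdaSketch:
    """Wrapper that calls a function for its side effect only."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, rng_key, resources, initial_position, data):
        self.fn(rng_key, resources, initial_position, data)  # return value discarded
        return rng_key, resources, initial_position


local_stepper = StepperSketch(position=1.0)
global_stepper = StepperSketch(position=0.0)


def sync_global_step(rng_key, resources, initial_position, data):
    # Copy the local position to the global stepper, then return the
    # strategy tuple explicitly so the callback is correct on its own.
    global_stepper.set_current_position(local_stepper.current_position)
    return rng_key, resources, initial_position


update_global_step = LambdaSketch(sync_global_step)
result = update_global_step("key", {}, 0.5, {})
```

With a wrapper of this shape either callback style produces the same result; the explicit return only matters if the wrapper ever starts forwarding the callback's value.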
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py`:
- Line 250: Add a "# noqa: F722" comment to the jaxtyping forward-annotation
lines that use multi-word shape strings (e.g., the parameter signature
"initial_position: Float[Array, \"n_chains n_dim\"]") to silence Ruff F722;
update the exact signature lines at the shown locations (and the analogous lines
at 282 and 322 and the same patterns in RQSpline_GRW_PT.py, RQSpline_HMC.py,
RQSpline_MALA.py, and RQSpline_MALA_PT.py) by appending "# noqa: F722" to those
annotation lines so the multi-word shape literals are ignored by the
forward-annotation checker.

In `@src/flowMC/resource/model/flowmatching/base.py`:
- Around line 275-285: The train function's docstring is out of sync: remove the
obsolete `model` parameter mention and update parameter and return descriptions
to match the current signature (e.g., `rng`, `data`, `num_epochs`, `batch_size`,
`verbose`, and the optimizer/state objects actually passed such as `optim` or
`state`), and ensure the Returns section lists the actual returned items (e.g.,
updated `rng`, `state`/`optim`, and `loss_values` as appropriate); update
parameter names and types to match those used in the `train` function in this
file (search for `def train(`) and keep descriptions concise and accurate.
- Around line 171-172: The return type annotation for FlowMatchingModel.sample
is missing the batch dimension: update the annotated return shape to include
num_samples (e.g., change Float[Array, " n_dim"] to include the batch axis such
as Float[Array, " num_samples n_dim"]) so the signature for sample(self,
rng_key: Key, num_samples: int, dt: Float = 1e-1) -> ... correctly reflects a
batched array; locate the method named sample in class FlowMatchingModel in
base.py and adjust its return type annotation accordingly.

In `@src/flowMC/resource/model/nf_model/base.py`:
- Line 67: The forward-annotation strings contain leading spaces which cause
F722 errors; remove the stray spaces inside all Array dimension string
annotations (e.g., change Float[Array, " n_dim"] to Float[Array, "n_dim"] and
Float[Array, " n_batch n_dim"] to Float[Array, "n_batch n_dim"]) across this
module (methods using the Float[Array, "..."] annotation such as the method with
signature containing x: Float[Array, " n_dim"]) and also in
src/flowMC/resource_strategy_bundle/RQSpline_GRW.py so that no dimension string
contains leading spaces.

In `@src/flowMC/strategy/check_early_stop.py`:
- Around line 72-74: Docstring for the function/class that defines the
"patience" parameter (see signature with patience: int = 3) is
inconsistent—docstring says "Defaults to 1" while the actual default is 3;
update either the docstring or the function signature so they match. Locate the
"patience" parameter in check_early_stop (or the function/class in
src/flowMC/strategy/check_early_stop.py that declares patience: int = 3) and
change the docstring default text to "Defaults to 3" (or change the signature
default to 1 if intended) so the declared default and the docstring are
consistent. Ensure the docstring line describing patience is updated exactly
where the parameter is documented.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2416b11 and 1e44f82.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (43)
  • .coverage
  • .github/renovate.json
  • .pre-commit-config.yaml
  • docs/quickstart.md
  • pyproject.toml
  • src/flowMC/Sampler.py
  • src/flowMC/resource/kernel/Gaussian_random_walk.py
  • src/flowMC/resource/kernel/HMC.py
  • src/flowMC/resource/kernel/MALA.py
  • src/flowMC/resource/kernel/NF_proposal.py
  • src/flowMC/resource/kernel/base.py
  • src/flowMC/resource/model/common.py
  • src/flowMC/resource/model/flowmatching/base.py
  • src/flowMC/resource/model/nf_model/base.py
  • src/flowMC/resource/model/nf_model/realNVP.py
  • src/flowMC/resource/model/nf_model/rqSpline.py
  • src/flowMC/resource_strategy_bundle/RQSpline_GRW.py
  • src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py
  • src/flowMC/resource_strategy_bundle/RQSpline_HMC.py
  • src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py
  • src/flowMC/strategy/adapt_step_size.py
  • src/flowMC/strategy/base.py
  • src/flowMC/strategy/check_early_stop.py
  • src/flowMC/strategy/lambda_function.py
  • src/flowMC/strategy/optimization.py
  • src/flowMC/strategy/parallel_tempering.py
  • src/flowMC/strategy/sequential_monte_carlo.py
  • src/flowMC/strategy/take_steps.py
  • src/flowMC/strategy/train_model.py
  • src/flowMC/strategy/update_state.py
  • tests/integration/test_HMC.py
  • tests/integration/test_MALA.py
  • tests/integration/test_RWMCMC.py
  • tests/integration/test_normalizingFlow.py
  • tests/integration/test_quickstart.py
  • tests/unit/test_bundle.py
  • tests/unit/test_flowmatching.py
  • tests/unit/test_kernels.py
  • tests/unit/test_nf.py
  • tests/unit/test_resources.py
  • tests/unit/test_strategies.py
💤 Files with no reviewable changes (1)
  • .github/renovate.json

@coderabbitai (bot) left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/flowMC/strategy/optimization.py (1)

95-122: ⚠️ Potential issue | 🟡 Minor

Docstring is misplaced after executable code.

The docstring for the optimize method is placed after the bounds validation logic (lines 95-101). A docstring must be the first statement in the function body to be stored in __doc__ and recognised by Python's documentation tools.

💡 Suggested fix
     def optimize(
         self,
         rng_key: Key,
         objective: Callable,
         initial_position: Float[Array, "n_chain n_dim"],
         data: dict,
     ):
+        """Optimization kernel. This can be used independently of the __call__ method.
+
+        Args:
+            rng_key: Key
+                Random key for the optimization.
+            objective: Callable
+                Objective function to optimize.
+            initial_position: Float[Array, "n_chain n_dim"]
+                Initial positions for the optimization.
+            data: dict
+                Data to pass to the objective function.
+
+        Returns:
+            rng_key: Key
+                Updated random key.
+            optimized_positions: Float[Array, "n_chain n_dim"]
+                Optimized positions.
+            final_log_prob: Float[Array, "n_chain"]
+                Final log-probabilities of the optimized positions.
+        """
         # Validate bounds shape against n_dim
         n_dim = initial_position.shape[-1]
         if not (self.bounds.shape[0] == 1 or self.bounds.shape[0] == n_dim):
             raise ValueError(
                 f"bounds shape {self.bounds.shape} is incompatible with n_dim={n_dim}. "
                 "Provide bounds of shape (1, 2) for broadcasting or (n_dim, 2) for per-dimension bounds."
             )
-
-        """Optimization kernel. This can be used independently of the __call__ method.
-        ...
-        """
         grad_fn = jax.jit(jax.grad(objective))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/strategy/optimization.py` around lines 95 - 122, The optimize
method's docstring is incorrectly placed after executable bounds validation;
move the triple-quoted docstring so it immediately follows the def optimize(...)
signature (before any code like the bounds check using self.bounds and
initial_position), preserving the exact docstring content and indentation, so
Python documentation tools will recognize it and the validation logic (n_dim,
bounds shape check) remains unchanged after the docstring.
♻️ Duplicate comments (2)
src/flowMC/resource/model/flowmatching/base.py (2)

170-172: ⚠️ Potential issue | 🟡 Minor

Return type annotation is missing the batch dimension.

The sample method returns num_samples samples, but the return type annotation indicates a single vector Float[Array, " n_dim"] instead of Float[Array, "num_samples n_dim"].

💡 Suggested fix
     def sample(
         self, rng_key: Key, num_samples: int, dt: Float = 1e-1
-    ) -> Float[Array, " n_dim"]:
+    ) -> Float[Array, "num_samples n_dim"]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/flowmatching/base.py` around lines 170 - 172, The
sample method's return type annotation on class/method sample is missing the
batch dimension; update the return annotation from Float[Array, " n_dim"] to
include the num_samples dimension (e.g. Float[Array, "num_samples n_dim"]) so it
correctly reflects that sample(self, rng_key: Key, num_samples: int, dt: Float =
1e-1) returns a batch of num_samples vectors; apply this change to the sample
signature in flowmatching.base.py and ensure any whitespace/typing string
formatting matches surrounding annotations.
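
The shape mismatch is easy to see with a toy sampler. `sample_sketch` below is a hypothetical function that draws standard-normal samples; it is not the flow-matching `sample` method, and the `n_dim` argument is an assumption added so the example is self-contained.

```python
import jax


def sample_sketch(rng_key, num_samples: int, n_dim: int = 2):
    # With jaxtyping this would be annotated as
    # Float[Array, "num_samples n_dim"], since one row is drawn per sample.
    return jax.random.normal(rng_key, (num_samples, n_dim))


samples = sample_sketch(jax.random.key(0), num_samples=5)
```

The result has two axes, so an annotation with only `n_dim` would describe a single row rather than the returned batch.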

274-285: ⚠️ Potential issue | 🟡 Minor

Docstring is out of sync with the function signature.

The docstring still documents a model argument that does not exist in the signature, and the Returns section is incomplete (missing state return value).

💡 Suggested docstring correction
         """Train a normalizing flow model.

         Args:
             rng (Key): JAX PRNGKey.
-            model (eqx.Module): NF model to train.
             data (Array): Training data.
+            optim (optax.GradientTransformation): Optimiser transformation.
+            state (optax.OptState): Optimiser state.
             num_epochs (int): Number of epochs to train for.
             batch_size (int): Batch size.
             verbose (bool): Whether to print progress.

         Returns:
             rng (Key): Updated JAX PRNGKey.
-            model (eqx.Model): Updated NF model.
+            model (Self): Updated NF model.
+            state (optax.OptState): Updated optimiser state.
             loss_values (Array): Loss values.
         """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/flowmatching/base.py` around lines 274 - 285, The
docstring documents a `model` parameter that is not in the signature and omits
the `state` return value; update the Args section of the function containing the
shown Args so that it matches the actual signature (`rng`, `data`, `optim`,
`state`, `num_epochs`, `batch_size`, `verbose`, as present), removing any
reference to `model`, and update the Returns section to list `rng`, the updated
model, `state`, and `loss_values` (if returned) with their correct types and
short descriptions; also adjust the top summary line to reflect the real
behaviour of the function and ensure the parameter names match the symbols in
the implementation exactly.
🧹 Nitpick comments (3)
src/flowMC/resource/model/nf_model/realNVP.py (1)

177-182: Unused key parameter in forward method.

The key parameter is accepted but never used in the method body. RealNVP is a deterministic flow, so this is expected behaviour. If the parameter exists solely for interface compatibility with stochastic flow models, consider documenting this, or prefixing it with an underscore (_key) to suppress the linter warning. Note that renaming changes the keyword-argument name, so any caller passing key=... would need updating.

♻️ Optional: Suppress linter warning
     def forward(
         self,
         x: Float[Array, " n_dim"],
-        key: Optional[Key] = None,
+        _key: Optional[Key] = None,  # Unused; kept for interface compatibility
         condition: Optional[Float[Array, " n_condition"]] = None,
     ) -> tuple[Float[Array, " n_dim"], Float]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/nf_model/realNVP.py` around lines 177 - 182, The
forward method in RealNVP (def forward(...)) accepts a key parameter but never
uses it; to fix, either rename the parameter to _key to suppress linter warnings
or add a short docstring/comment in the RealNVP.forward body stating that key is
accepted for interface compatibility with stochastic flows and is intentionally
unused; if you rename it, update every caller that passes key by keyword so the
calls still match the signature.
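
The two options above can be compared with a minimal, library-free sketch. The class names and the arithmetic are hypothetical stand-ins, not flowMC code; the point is only that documenting an unused parameter keeps keyword callers working, whereas renaming it does not.

```python
from typing import Optional


class DeterministicFlow:  # hypothetical stand-in for a deterministic flow
    def forward(self, x: float, key: Optional[object] = None) -> float:
        """Apply the flow to x.

        key is accepted only for interface compatibility with stochastic
        flows and is intentionally unused.
        """
        del key  # common idiom to mark a parameter as deliberately unused
        return 2.0 * x


class RenamedFlow:  # same flow, with the parameter renamed to _key
    def forward(self, x: float, _key: Optional[object] = None) -> float:
        return 2.0 * x


# Keyword callers keep working when the parameter name is unchanged.
print(DeterministicFlow().forward(1.5, key=None))

# After the rename, a caller that still passes key=... fails.
try:
    RenamedFlow().forward(1.5, key=None)
except TypeError as err:
    print("keyword caller broke:", err)
```

Which option fits depends on whether any caller passes the argument by keyword; that would need checking against the actual call sites.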
src/flowMC/resource/model/common.py (1)

54-55: Consider documenting or removing the unused key parameter.

The key parameter in Distribution.__call__ is currently unused. If this is intentional for interface consistency with subclasses that may need randomness, consider adding a brief docstring explaining this. Otherwise, it could be removed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/common.py` around lines 54 - 55, The __call__
method on Distribution currently accepts an unused key parameter; either
document its purpose for interface consistency or remove it: add a short
docstring to Distribution.__call__ explaining that key is optional and unused by
the base implementation but kept so subclasses that require RNG (e.g., sampling
subclasses overriding __call__ or methods like log_prob) can accept it, or
remove key from Distribution.__call__ signature and update any subclass
overrides and callers to match (search for __call__ and log_prob in the class to
locate affected code).
src/flowMC/strategy/optimization.py (1)

42-54: Consider using None as default and initialising bounds inside the function.

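Using jnp.array(...) as a default argument value triggers B008, the flake8-bugbear rule (also implemented by Ruff) against function calls in argument defaults, because the call is evaluated once when the function is defined. While JAX arrays are immutable and this is safe in practice, the idiomatic pattern is to use None as the default and initialise inside the function body.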

💡 Suggested fix
     def __init__(
         self,
         logpdf: Callable[[Float[Array, " n_dim"], dict], Float],
         n_steps: int = 100,
         learning_rate: float = 1e-2,
         noise_level: float = 10,
-        bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
+        bounds: Float[Array, "n_dim 2"] | None = None,
     ):
         self.logpdf = logpdf
         self.n_steps = n_steps
         self.learning_rate = learning_rate
         self.noise_level = noise_level
-        self.bounds = bounds
+        self.bounds = bounds if bounds is not None else jnp.array([[-jnp.inf, jnp.inf]])

Note: The class attribute default and docstring would also need updating.
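
For reference, here is a runnable, dependency-free sketch of the same None-default pattern. Plain Python lists stand in for the JAX array, and the class name BoundedOptimizer is hypothetical, so treat it as an illustration of the idiom rather than the actual flowMC constructor.

```python
import math
from typing import Optional, Sequence


class BoundedOptimizer:  # hypothetical class, not the flowMC implementation
    def __init__(
        self,
        n_steps: int = 100,
        bounds: Optional[Sequence[Sequence[float]]] = None,
    ):
        self.n_steps = n_steps
        # Build the default inside the body, so that no function call is
        # evaluated in the signature at definition time.
        if bounds is None:
            self.bounds = [[-math.inf, math.inf]]
        else:
            self.bounds = [list(pair) for pair in bounds]


default = BoundedOptimizer()
custom = BoundedOptimizer(bounds=[(0.0, 1.0), (-2.0, 2.0)])
print(default.bounds)
print(custom.bounds)
```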

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/strategy/optimization.py` around lines 42 - 54, The __init__
method currently uses a JAX array literal as the default for the bounds
parameter which triggers B008; change the signature to use bounds:
Optional[...]=None and inside __init__ set self.bounds = jnp.array([[-jnp.inf,
jnp.inf]]) if bounds is None else bounds (ensuring the type matches expected
Float[Array, "n_dim 2"]), and update the constructor docstring/default class
attribute comment accordingly; refer to the __init__ parameter name bounds and
the instance attribute self.bounds when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/flowMC/resource/model/nf_model/base.py`:
- Around line 166-179: Update the docstring for the train function in
nf_model.base to match the actual signature: remove the nonexistent `model`
parameter from the Args section, ensure the listed arguments (e.g., rng, data,
num_epochs, batch_size, verbose) exactly match the function signature, and add
the missing `state` to the Returns section (include its type, e.g.,
optimizer/state object or eqx state) alongside rng, model, and loss_values;
reference the function by name (train) and the returned `state` symbol when
making these edits.

---

Outside diff comments:
In `@src/flowMC/strategy/optimization.py`:
- Around line 95-122: The optimize method's docstring is incorrectly placed
after executable bounds validation; move the triple-quoted docstring so it
immediately follows the def optimize(...) signature (before any code like the
bounds check using self.bounds and initial_position), preserving the exact
docstring content and indentation, so Python documentation tools will recognize
it and the validation logic (n_dim, bounds shape check) remains unchanged after
the docstring.
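
A short illustration of why the placement matters: Python only treats a string literal as the docstring when it is the first statement in the function body. The two functions below are hypothetical examples, not the optimize method itself.

```python
def misplaced(x):
    if x < 0:
        raise ValueError("x must be non-negative")
    """Return x doubled."""  # evaluated and discarded; not a docstring
    return 2 * x


def well_placed(x):
    """Return x doubled."""
    if x < 0:
        raise ValueError("x must be non-negative")
    return 2 * x


print(misplaced.__doc__)    # None
print(well_placed.__doc__)  # Return x doubled.
```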

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1e44f82 and 7ce0e2d.

📒 Files selected for processing (11)
  • .pre-commit-config.yaml
  • src/flowMC/resource/buffers.py
  • src/flowMC/resource/kernel/HMC.py
  • src/flowMC/resource/kernel/MALA.py
  • src/flowMC/resource/kernel/base.py
  • src/flowMC/resource/model/common.py
  • src/flowMC/resource/model/flowmatching/base.py
  • src/flowMC/resource/model/nf_model/base.py
  • src/flowMC/resource/model/nf_model/realNVP.py
  • src/flowMC/strategy/optimization.py
  • src/flowMC/strategy/parallel_tempering.py
✅ Files skipped from review due to trivial changes (1)
  • src/flowMC/resource/buffers.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • src/flowMC/resource/kernel/MALA.py
  • .pre-commit-config.yaml

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
src/flowMC/resource/model/flowmatching/base.py (1)

171-172: ⚠️ Potential issue | 🟡 Minor

FlowMatchingModel.sample return annotation still misses the batch axis.

At Line 172, the annotated return type is a single vector, but this method returns num_samples rows.

💡 Annotation fix
-    ) -> Float[Array, " n_dim"]:
+    ) -> Float[Array, "num_samples n_dim"]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/flowMC/resource/model/flowmatching/base.py` around lines 171 - 172, The
return type annotation of FlowMatchingModel.sample is incorrect: it declares a
single vector but actually returns num_samples rows; update the signature for
sample (method name FlowMatchingModel.sample with parameters rng_key,
num_samples, dt) to include the batch axis — e.g. change the return annotation
to include the leading num_samples dimension (such as Float[Array, "num_samples
n_dim"] or equivalent in your typing convention) so the annotated shape matches
the actual returned array.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/flowMC/resource/model/flowmatching/base.py`:
- Line 136: Several jaxtyping annotations in this file, such as the _data_cov
attribute, use multi-axis shape strings containing spaces (e.g., Float[Array,
"n_dim n_dim"]). Ruff reads these strings as forward references and reports F722
(syntax error in forward annotation). The shape strings themselves are valid
jaxtyping, so leave them unchanged and silence the rule instead: either append
"# noqa: F722" to every affected Float[Array, "..."] annotation line listed in
the review so CI passes, or, if you prefer project-wide handling, ignore F722
for this module in the Ruff config instead of using per-line noqa comments.

In `@src/flowMC/strategy/check_early_stop.py`:
- Around line 117-131: The assert statements here perform walrus assignments,
but Python strips assert statements when run with -O, so those assignments would
never execute and the names would be undefined later in the function. Separate
the assignments from the checks (look for state_name and acceptance_buffer_key
usage inside check_early_stop): first assign state = resources[self.state_name],
buffer_name = state.data.get(self.acceptance_buffer_key), and acceptance_buffer
= resources[buffer_name], then validate explicitly (either a plain assert
without a walrus, or an if check that raises TypeError/ValueError) that state is
a State, buffer_name is a str, and acceptance_buffer is a Buffer, so that the
variables exist regardless of optimisation flags.
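
A minimal sketch of the explicit-validation variant follows. State and Buffer here are hypothetical stand-ins for the flowMC resource classes, and get_acceptance_buffer is an illustrative helper, not a function from the PR. Because the assignments are ordinary statements, they run whether or not Python is started with -O.

```python
from dataclasses import dataclass, field


@dataclass
class State:  # hypothetical stand-in for the flowMC State resource
    data: dict = field(default_factory=dict)


@dataclass
class Buffer:  # hypothetical stand-in for the flowMC Buffer resource
    name: str


def get_acceptance_buffer(resources: dict, state_name: str, buffer_key: str) -> Buffer:
    # Assign first, then validate with explicit checks instead of asserts.
    state = resources[state_name]
    if not isinstance(state, State):
        raise TypeError(f"resource {state_name!r} must be a State")

    buffer_name = state.data.get(buffer_key)
    if not isinstance(buffer_name, str):
        raise TypeError(f"state entry {buffer_key!r} must name a buffer")

    acceptance_buffer = resources[buffer_name]
    if not isinstance(acceptance_buffer, Buffer):
        raise TypeError(f"resource {buffer_name!r} must be a Buffer")
    return acceptance_buffer


resources = {
    "sampler_state": State(data={"acceptance": "local_accs"}),
    "local_accs": Buffer(name="local_accs"),
}
print(get_acceptance_buffer(resources, "sampler_state", "acceptance"))
```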

ℹ️ Review info

📥 Commits

Reviewing files that changed from the base of the PR and between 7ce0e2d and 1745377.

⛔ Files ignored due to path filters (1)
  • uv.lock is excluded by !**/*.lock
📒 Files selected for processing (3)
  • src/flowMC/resource/model/flowmatching/base.py
  • src/flowMC/resource/model/nf_model/base.py
  • src/flowMC/strategy/check_early_stop.py

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py`:
- Line 57: Ruff reports F722 (syntax error in forward annotation) on the
jaxtyping shape strings in this file, but the strings themselves are valid and
should be kept as written. The leading space in grw_step_size: Float |
Float[Array, " n_dim"] is the form the jaxtyping FAQ recommends for single-axis
shapes: jaxtyping treats it as "n_dim", and it stops linters from reporting the
name as an undefined forward reference (F821). Multi-axis shapes such as
"n_chains n_dim" (several array/type hints throughout RQSpline_GRW_PT.py) are
space-separated by design. To clear the reports, either add a trailing # noqa:
F722 to the flagged annotations or ignore F722 in the Ruff configuration; treat
all flagged occurrences the same way.

In `@src/flowMC/resource_strategy_bundle/RQSpline_GRW.py`:
- Line 50: The jaxtyping annotation for grw_step_size uses a shape string with a
leading space (" n_dim"). This is intentional: the jaxtyping FAQ recommends the
leading space for single-axis shapes, jaxtyping treats it as "n_dim", and it
prevents an undefined-name report (F821), so keep it. The annotations that use
multi-axis shape strings like "n_chains n_dim" (there are four such occurrences
referenced in the review) are also valid jaxtyping. Suppress Ruff F722 by
appending "# noqa: F722" to each flagged annotation line, or ignore F722 in the
Ruff configuration, so that all reported F722 violations are resolved without
changing the shape strings.

In `@src/flowMC/resource_strategy_bundle/RQSpline_HMC.py`:
- Line 51: Two kinds of annotation are flagged, and neither needs its shape
string changed. Single-axis shapes written as " n_dim" (e.g., the annotation on
condition_matrix and the prior_mean annotation) carry a leading space on
purpose: the jaxtyping FAQ recommends it, jaxtyping treats it as "n_dim", and it
prevents an undefined-name report (F821). The four annotations that use the
shape string "n_chains n_dim" are valid multi-axis jaxtyping shapes. Locate the
annotations by searching for the exact substrings " n_dim" and "n_chains n_dim",
then silence Ruff F722 by appending "# noqa: F722" to the flagged lines or by
ignoring F722 in the Ruff configuration, so the jaxtyping syntax is preserved.

In `@src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py`:
- Line 57: Keep the jaxtyping shape strings in this file as written. The leading
space in " n_dim" in the annotations for logpdf, mala_step_size, and logprior
(look for parameters with those names in RQSpline_MALA_PT.py) is the form the
jaxtyping FAQ recommends for single-axis shapes: jaxtyping treats it as "n_dim",
and it prevents an undefined-name report (F821). For the annotations that
trigger Ruff F722, append a local suppression comment "# noqa: F722" to the
affected lines, including the parameters typed as Float[Array, "n_chains n_dim"]
(e.g., initial_position and any other such parameters in the function signatures
around the reported occurrences), or ignore F722 in the Ruff configuration; do
not alter the jaxtyping strings themselves.

In `@src/flowMC/resource_strategy_bundle/RQSpline_MALA.py`:
- Line 49: The jaxtyping shape strings in the RQSpline_MALA annotations are
valid and should not be rewritten. The leading space in the single-axis shape
string (" n_dim") for the parameter annotated as mala_step_size (and any other
annotations using " n_dim") follows the jaxtyping FAQ: it is treated as "n_dim"
and prevents an undefined-name report (F821). The four flagged annotations using
"n_chains n_dim" are in jaxtyping's documented space-separated form; do not
convert them to comma-separated strings. Keep or add "# noqa: F722" on the
flagged lines, or ignore F722 in the Ruff configuration, to resolve the reports.
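
For context on why F722 fires at all: linters try to parse a string annotation as a Python expression. The sketch below uses only the standard library to approximate that check (the helper name parses_as_expression is hypothetical, and a real linter's parser may differ in detail). It shows that the space-separated form fails to parse while the comma-separated form parses as a tuple. This explains the linter behaviour only; it does not show what jaxtyping accepts, and jaxtyping's documented multi-axis form is the space-separated one.

```python
import ast


def parses_as_expression(annotation: str) -> bool:
    """Return True if the string parses as a single Python expression."""
    try:
        ast.parse(annotation, mode="eval")
    except SyntaxError:
        return False
    return True


# Two names separated by a space are not a valid expression.
print(parses_as_expression("n_chains n_dim"))   # False
# A comma turns the string into a tuple expression, which does parse.
print(parses_as_expression("n_chains, n_dim"))  # True
# A single bare name parses as well.
print(parses_as_expression("n_dim"))            # True
```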

ℹ️ Review info

📥 Commits

Reviewing files that changed from the base of the PR and between 1745377 and 31bb6f4.

📒 Files selected for processing (6)
  • src/flowMC/resource_strategy_bundle/RQSpline_GRW.py
  • src/flowMC/resource_strategy_bundle/RQSpline_GRW_PT.py
  • src/flowMC/resource_strategy_bundle/RQSpline_HMC.py
  • src/flowMC/resource_strategy_bundle/RQSpline_HMC_PT.py
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA.py
  • src/flowMC/resource_strategy_bundle/RQSpline_MALA_PT.py

@thomasckng thomasckng merged commit 3e08e3b into main Mar 2, 2026
11 checks passed
@thomasckng thomasckng deleted the flowMC-dev branch March 2, 2026 09:12
@thomasckng thomasckng restored the flowMC-dev branch March 2, 2026 09:12