
Chunked + JAX-jit'd MBAR with vectorised batch evaluation#575

Open
jamesoliverh wants to merge 5 commits into choderalab:main from jamesoliverh:chunked-u-kn

Conversation


@jamesoliverh jamesoliverh commented Apr 26, 2026

This PR adds chunked working-array MBAR end-to-end, across both the fit and the post-fit batch evaluation. This caps peak working memory at O(chunk_size × K) regardless of sample count N, so MBAR can be fit and evaluated on datasets that previously exceeded available RAM.

A single chunk_size knob streams the fit's inner reductions over the sample axis and, on the evaluation side, replaces the per-target Python loop with one vectorised pass per chunk, skipping the dense (N, K + NL + S) log-weight allocation. JAX-jit'd per-chunk inner kernels run throughout (with a numpy fallback when JAX isn't available), measuring around 2× faster per chunk than the numpy path. Default behaviour without chunk_size is unchanged. A streaming-callable form for u_kn is also supported, so out-of-core fits don't need to load the full matrix into RAM.

Closes issue #574.
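A minimal usage sketch of the knob described above (shapes, file names, and the result variable are illustrative, not from the PR):

```python
import numpy as np
from pymbar import MBAR

u_kn = np.load("u_kn.npy")  # (K, N) reduced potentials; illustrative input
N_k = np.load("N_k.npy")    # samples drawn from each of the K states
u_ln = np.load("u_ln.npy")  # (L, N) potentials at the target states

# Opt-in chunking: peak working memory ~ O(chunk_size * K)
mbar = MBAR(u_kn, N_k, chunk_size=10_000)

# Post-fit batch evaluation reuses the same knob; per the constraints
# listed below, uncertainties are unavailable on the chunked path.
results = mbar.compute_perturbed_free_energies(
    u_ln, compute_uncertainty=False, chunk_size=10_000
)
```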

LOC summary

| Section | Impl LOC | Notes |
| --- | --- | --- |
| 1. End-to-end chunking (_chunked.py) | ~290 | Whole new module (322 LOC) |
| 2. Vectorised batch evaluation (mbar.py) | ~40 | chunk_size kwarg + chunked branch + 3 wrapper kwargs |
| 3. JAX work (_chunked.py kernels + mbar_solvers.py bypass) | ~70 | 4 @jit kernels + numpy fallback + JIT-bypass guard |
| Implementation total | ~400 | |
| Tests | ~625 | 5 test files |
| Cumulative PR diff | +1149 / −19 | |

Three areas of work

  1. End-to-end chunking. New pymbar/_chunked.py module holds the chunking primitives and per-chunk kernels. The fit's inner reductions and the batch-evaluation reductions both route through it under one chunk_size knob.
  2. Vectorised batch evaluation. compute_expectations_inner and the three public wrappers gain a chunk_size kwarg. When set, all target states are evaluated in one vectorised pass per chunk; the per-target Python loop and the dense (N, K + NL + S) allocation are gone on this branch. Default chunk_size=None runs the existing path bit-for-bit unchanged. return_theta=True and uncertainty_method='bootstrap' are rejected on the chunked path — neither composes cleanly with column-chunked kernels.
  3. JAX work. Four @jit-decorated per-chunk kernels in _chunked.py, with a clean numpy fallback (@jit reduces to identity, jnp becomes np) when JAX is unavailable or PYMBAR_DISABLE_JAX=1; a sketch of this fallback pattern follows the list. The outer jax.jit that wraps the adaptive solver is bypassed when chunking is active, since the chunked helpers have side-effecting Python loops that can't be traced.
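A minimal sketch of that fallback pattern, assuming the conventional try/except import shape (the PR's actual kernels and names live in _chunked.py):

```python
import os
import numpy as np

try:
    if os.environ.get("PYMBAR_DISABLE_JAX"):
        raise ImportError("JAX disabled by PYMBAR_DISABLE_JAX")
    import jax.numpy as jnp
    from jax import jit
except ImportError:
    jnp = np                    # jnp becomes plain numpy

    def jit(fn, **kwargs):      # @jit reduces to the identity
        return fn

@jit
def chunk_logsumexp(a):
    # One illustrative per-chunk kernel: logsumexp over the state axis
    # of a (chunk, K) block, written against jnp so the same code runs
    # under JAX or numpy.
    m = jnp.max(a, axis=1, keepdims=True)
    return (m + jnp.log(jnp.sum(jnp.exp(a - m), axis=1, keepdims=True)))[:, 0]
```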

Tests

Unit tests for the new chunked helpers against scipy.special.logsumexp. Integration tests for the three public methods against the dense reference at decimal=10, parametrised over a range of chunk_size. Constraint tests for the return_theta=True / bootstrap rejection paths. A regression test that forces the adaptive solver and clears the JAX cache to exercise the JIT-bypass cold trace. 131 tests pass, 9 xfailed pre-existing, no regressions.
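A hedged sketch of the helper-equivalence test shape (chunk_logsumexp is a hypothetical helper name, continuing the fallback sketch above; the PR's real fixtures differ):

```python
import numpy as np
import pytest
from scipy.special import logsumexp

from pymbar._chunked import chunk_logsumexp  # hypothetical helper name

@pytest.mark.parametrize("chunk_size", [1, 7, 128, 10_000])
def test_chunked_matches_scipy(chunk_size):
    rng = np.random.default_rng(0)
    a = rng.normal(size=(50, 1_000))          # (K, N) toy potentials
    ref = logsumexp(a, axis=0)                # dense scipy reference
    got = np.concatenate([
        chunk_logsumexp(a[:, s:s + chunk_size].T)  # (chunk, K) slices
        for s in range(0, a.shape[1], chunk_size)
    ])
    np.testing.assert_array_almost_equal(got, ref, decimal=10)
```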

Memory / performance

For workloads where N ≫ chunk_size and K + NL + S is large, target-axis work moves from a Python loop to one vectorised reduction per chunk, and peak memory drops by O(N / chunk_size). JAX-jit'd per-chunk kernels measure around 2× faster than the numpy path.

Out of scope (follow-ons if requested)

  • Streaming u_ln via a callable yielding column chunks (motivation weaker — u_ln is usually constructed cheaply from a target-state list).
  • Theta from chunked weights (would need a redesign of _computeAsymptoticCovarianceMatrix).
  • Bootstrap with chunked path (feasible by re-running chunked helpers per replicate, not in this PR).

Notes for review

The PR is broader than issue #574's original scope: the chunked path covers fit + post-fit eval under one chunk_size knob, with JAX-jit'd per-chunk kernels throughout. If splitting the batch-eval helpers and the compute_* changes off into a separate PR would help review, I can rebase that piece onto its own branch.

When N is large enough that the (N, K) working-array temporaries inside
scipy.special.logsumexp(... b=N_k) exceed available RAM, the solver
OOMs even though u_kn itself fits. This adds an opt-in chunk_size
kwarg that streams these intermediates in (chunk_size, K) slices,
bounding peak working memory to O(chunk_size * K).
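A sketch of the streamed reduction this describes, with the b=N_k weights folded into a log N_k offset (helper name hypothetical):

```python
import numpy as np
from scipy.special import logsumexp

def chunked_log_denominators(u_kn, f_k, log_N_k, chunk_size):
    # Hypothetical helper: computes, per sample n,
    #   logsumexp_k(f_k + log N_k - u_kn[k, n])
    # in (chunk_size, K) slices so the dense (N, K) temporary
    # never exists.
    K, N = u_kn.shape
    out = np.empty(N)
    for s in range(0, N, chunk_size):
        w = f_k + log_N_k - u_kn[:, s:s + chunk_size].T  # (chunk, K)
        out[s:s + chunk_size] = logsumexp(w, axis=1)
    return out
```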

The chunked path is numpy-only by design (per-chunk JIT recompilation
would defeat the point). The JAX-jit'd dense path is unchanged when
chunk_size=None, the default.

Side benefit: the chunked numpy path is also 3-4x faster than the
dense numpy path at large N, because the per-chunk working array
fits in CPU cache while the dense (N, K) intermediate is RAM-bandwidth
bound. End-to-end MBAR fit (K=50, N=100k):
  dense numpy:   28.7 s
  chunked numpy:  8.0 s   (3.6x faster)
  max |f_dense - f_chunked| = 1.78e-15

Also adds cache_log_W_nk kwarg to skip the post-fit (N, K) Log_W_nk
allocation when the methods that consume it (compute_overlap,
compute_effective_sample_number, weights, Theta-based uncertainty)
are not needed.
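Presumed call shape for the new kwarg (a guess from the description above; u_kn and N_k as in the earlier sketch):

```python
from pymbar import MBAR

# Skip the post-fit (N, K) Log_W_nk allocation when its consumers
# (compute_overlap, compute_effective_sample_number, weights,
# Theta-based uncertainties) will not be called.
mbar = MBAR(u_kn, N_k, chunk_size=10_000, cache_log_W_nk=False)
```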

Tests: 35 new (24 helper-equivalence vs scipy + 8 end-to-end MBAR +
3 tracemalloc memory regression). Existing pymbar suite passes
unchanged (129 passed, 21 skipped, 21 xfailed).

Out of scope (PR2 candidates):
- Streaming u_kn itself (when even u_kn doesn't fit).
- Bootstrap with chunk_size (currently raises NotImplementedError).
- compute_expectations for new target states with chunked u_ln.
- JAX-jit'd chunked path with locked chunk size.
@jamesoliverh changed the title from "Chunked working-array path for large-N MBAR fits (closes #574)" to "Chunked working-array path for large-N MBAR fits (refs #574)" on Apr 26, 2026
…b#574)

Extends the chunk_size kwarg from the previous commit: when u_kn is a
zero-arg callable returning a chunk iterator instead of a dense ndarray,
the solver streams chunks from the callable on each pass, never
materialising the full (K, N) matrix.

Three modes are now supported via the same kwargs:

  | u_kn     | chunk_size | behaviour                                |
  |----------|------------|------------------------------------------|
  | ndarray  | None       | dense + JAX (existing baseline)          |
  | ndarray  | int        | dense u_kn, chunked working temporaries  |
  | callable | int (req)  | streaming u_kn + chunked working temps   |

The dense-ndarray + chunk_size mode is the dominant use case (most
users have a u_kn that fits in RAM and just need the (N, K) doubling
bounded). The callable mode unlocks the additional case where u_kn
itself doesn't fit, and works transparently with mmap'd .npy files via
np.load(path, mmap_mode='r') on the dense branch.
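A sketch of the callable mode; the (K, chunk) column-block orientation, the factory name, and the file names are assumptions:

```python
import numpy as np
from pymbar import MBAR

def make_u_kn_stream(path, chunk_size):
    # Hypothetical zero-arg callable factory: yields (K, chunk) column
    # blocks from an on-disk .npy via memory mapping, so the full
    # (K, N) matrix is never resident.
    def u_kn_chunks():
        u = np.load(path, mmap_mode="r")
        for s in range(0, u.shape[1], chunk_size):
            yield np.asarray(u[:, s:s + chunk_size])
    return u_kn_chunks

N_k = np.load("N_k.npy")
# chunk_size is required whenever u_kn is a callable (else ValueError)
mbar = MBAR(make_u_kn_stream("u_kn.npy", 10_000), N_k, chunk_size=10_000)
```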

Memory at the streaming dispatch (tracemalloc, K=50 N=200k chunk=10k):
  dense (N, K) reference: ~80 MB
  streaming chunked:      ~4 MB   (>20x smaller)

The streaming path mutates nothing; preconditioning is implemented as
two streaming passes that build an O(N) shift array, applied lazily
at chunk-read time via a wrapping callable. MBAR's invariance under
per-sample shifts keeps f_k and Log_W_nk unchanged.
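The invariance holds because a per-sample shift s_n multiplies both the numerator exp(f_k − u_kn) and the mixture denominator Σ_l N_l exp(f_l − u_ln) of each weight by e^{s_n}, which cancels. A sketch of the wrapping callable (names hypothetical):

```python
import numpy as np

def apply_shift(u_kn_chunks, shift):
    # Hypothetical wrapper: subtracts the O(N) per-sample shift from
    # each (K, chunk) block lazily at read time, so the shifted matrix
    # is never materialised and the source callable is never mutated.
    def wrapped():
        start = 0
        for chunk in u_kn_chunks():
            stop = start + chunk.shape[1]
            yield chunk - shift[start:stop]  # broadcasts over the K rows
            start = stop
    return wrapped
```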

Constraints when u_kn is callable:
- chunk_size must be set (else ValueError)
- cache_log_W_nk=True is rejected (the cache itself is (N, K))
- n_bootstraps>0 raises NotImplementedError (needs random-sample access)
- N_k=0 states rejected (state-axis filtering would need a wrapped callable)

13 new tests in pymbar/tests/test_chunked_streaming.py covering
end-to-end equivalence vs dense, helper-level equivalence, peak memory
via tracemalloc, validation rejections, and chunk iteration. Existing
pymbar suite (129 passed, 21 skipped, 21 xfailed) unchanged.
@jamesoliverh changed the title from "Chunked working-array path for large-N MBAR fits (refs #574)" to "Chunked working-array path + streaming u_kn for large-N MBAR fits (closes #574)" on Apr 26, 2026
The chunked working-array helpers in _chunked.py run Python loops over
chunks with mutable numpy output buffers. They cannot be traced inside
an outer jax.jit: under the 'adaptive' solver, MBAR(..., chunk_size=N)
either fails with ConcretizationTypeError on a cold JAX cache, or
silently reuses the cached dense jaxpr on a warm cache (since
_CHUNK_SIZE is a Python global that jit cannot key on).

Bypass the outer jit in staggered_jit when _CHUNK_SIZE is not None so
the chunked dispatch in mbar_gradient/mbar_objective/etc. actually
runs.

Regression test forces solver_protocol=(adaptive,) and clears the JAX
cache so the trace path is exercised; DEFAULT_SOLVER_PROTOCOL prefers
'hybr' first which would otherwise mask the bug on small fixtures.
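A sketch of the bypass, under the assumption that staggered_jit wraps solver kernels roughly like this (pymbar's real decorator differs in detail):

```python
import jax

_CHUNK_SIZE = None  # set by the chunked entry points, per the commit above

def staggered_jit(fn, **jit_kwargs):
    jitted = jax.jit(fn, **jit_kwargs)

    def wrapper(*args, **kwargs):
        # _CHUNK_SIZE is a Python global that jit cannot key on, so the
        # check must happen eagerly at call time: when chunking is
        # active, run the untraceable chunked dispatch outside jit.
        if _CHUNK_SIZE is not None:
            return fn(*args, **kwargs)
        return jitted(*args, **kwargs)

    return wrapper
```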
…nk kernels

Extends the chunked path introduced for the fit (choderalab#574) to
post-fit batch evaluation, and adds JAX-jit'd per-chunk inner kernels
shared by every chunked helper. The numpy fallback is preserved when
JAX is unavailable or PYMBAR_DISABLE_JAX is set.

Adds chunk_size kwarg to compute_expectations,
compute_multiple_expectations, compute_perturbed_free_energies; replaces
the per-target Python for loops in compute_expectations_inner with one
vectorised BLAS pass per chunk via two new helpers
(chunked_log_numerator_l, chunked_log_observable). The dense
(N, K+NL+S) Log_W_nk allocation is skipped on the chunked branch.
Constraints: chunk_size requires compute_uncertainty=False (Theta needs
the full Log_W_nk) and uncertainty_method != 'bootstrap'.

Tests parametrised over chunk_size; existing dense path bit-for-bit
unchanged.
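The chunkwise accumulation can rest on the streaming logsumexp identity log(e^a + e^b) = logaddexp(a, b); a sketch of how per-target totals could be combined across chunks without the dense (N, K+NL+S) array (function name hypothetical):

```python
import numpy as np

def streaming_logsumexp(log_w_chunks):
    # Accumulate logsumexp over the sample axis chunk by chunk; each
    # element of log_w_chunks is an (L, chunk) block of log-weights.
    total = None
    for log_w in log_w_chunks:
        part = np.logaddexp.reduce(log_w, axis=1)  # per-chunk lse, (L,)
        total = part if total is None else np.logaddexp(total, part)
    return total  # (L,): logsumexp over all N samples, per target state

# Perturbed free energies then follow as f_l = -streaming_logsumexp(...)
```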
@mrshirts
Collaborator

Being able to deal with larger matrices could be important. A few thoughts:

  • I am worried about making MBAR less maintainable. There are limited (near-zero?) resources for maintenance, and we are behind as it is on robustness issues. This change would complicate the code in a way that would likely make it harder to maintain.
  • The PR reads as though it was AI-generated (formatting, sentence structure, level of detail). Could @jamesoliverh comment a bit more on this and walk through the logic? I wasn't able to figure out who this is with 5 minutes on his repositories.
  • The "right" way to do this is not to construct the entire MBAR u_kn matrix at all, but to stitch together overlapping locally solved states; that work has been stuck for 15 years, though...

@jamesoliverh
Author

jamesoliverh commented Apr 26, 2026

Hi, thank you for your comments / concerns.

Indeed the PR was generated with the help of Opus 4.7. I trust that's not an issue?

I'm actively using this package for a project and used Opus to draft an improvement that gets me around a significant bottleneck. IMHO it represents a significant improvement in functionality and in the ability to scale to larger problems, centred on chunking (a) the internal operations, for significantly less memory use and (surprisingly, maybe) increased speed, and (b) the u_kn matrix itself.

I'm also working on the JAX implementation and vectorised batch evaluation post-fit with chunking (again a significant improvement in my workflow), to be added shortly.

I'm open to all feedback and will iterate on it. If you want to know more about me, my Google Scholar is here:

https://scholar.google.com/citations?user=vYV9IgQAAAAJ&hl=en

and LinkedIn is here:

https://www.linkedin.com/in/james-hamp-327260126?utm_source=share&utm_campaign=share_via&utm_content=profile&utm_medium=ios_app

(not sure what else you might be looking for in terms of info about me).

I have contributed (pre Gen AI!) to numpy and various other libs.

Thanks very much for commenting and really look forward to working together to get this functionality into the lib.

… messages, parametrise state_dependent test

Pure cleanup, no behaviour change. -40 net LOC across the eval helpers,
the chunked branch in compute_expectations_inner, and the integration
test file. Tests still 131 pass / 9 xfailed pre-existing.
@jamesoliverh changed the title from "Chunked working-array path + streaming u_kn for large-N MBAR fits (closes #574)" to "Chunked + JAX-jit'd MBAR with vectorised batch evaluation" on Apr 26, 2026
@mrshirts
Collaborator

Thanks for the information on the background!

Indeed the PR was generated with the help of Opus 4.7. I trust that's not an issue?

I don't think there's an issue per se, but there is the worry that the submitter may not understand the logic, and that there could be underlying changes happening that we don't realize, since AI can occasionally hallucinate, as I have certainly seen in my own reviewing of code by Copilot.

We do need an official AI code policy!
