-
Notifications
You must be signed in to change notification settings - Fork 120
Parallel tempering for ebm samplers #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| from thrml import Block, SpinNode, make_empty_block_state | ||
| from thrml.models import IsingEBM, IsingSamplingProgram | ||
| from thrml.tempering import parallel_tempering | ||
|
|
||
|
|
||
def _tiny_ising():
    """Build a 2x2 wrap-around lattice of spin nodes with a checkerboard block split.

    Returns:
        A ``(nodes, edges, free_blocks)`` tuple:

        * ``nodes`` -- the four ``SpinNode`` objects in row-major order.
        * ``edges`` -- for every site, a "right" edge followed by a "down"
          edge; neighbour indices wrap modulo the side length.
        * ``free_blocks`` -- two ``Block`` objects, the first holding the
          sites with even ``row + col`` and the second the odd ones.
    """
    side = 2
    lattice = [[SpinNode() for _ in range(side)] for _ in range(side)]
    sites = [(row, col) for row in range(side) for col in range(side)]

    nodes = [lattice[row][col] for row, col in sites]

    edges = []
    for row, col in sites:
        here = lattice[row][col]
        right = lattice[row][(col + 1) % side]
        down = lattice[(row + 1) % side][col]
        edges.extend([(here, right), (here, down)])

    # Two-colouring: bucket 0 collects even-parity sites, bucket 1 odd-parity.
    by_parity = ([], [])
    for row, col in sites:
        by_parity[(row + col) % 2].append(lattice[row][col])

    free_blocks = [Block(by_parity[0]), Block(by_parity[1])]
    return nodes, edges, free_blocks
|
|
||
|
|
||
def test_parallel_tempering_smoke():
    """Smoke-test ``parallel_tempering`` on two chains over the tiny lattice.

    Biases and weights are all zero, so (per the original author's note) every
    proposed swap should be accepted. With two rounds and a single adjacent
    pair, the test expects exactly one attempted and one accepted swap.
    """
    nodes, edges, free_blocks = _tiny_ising()
    zero_biases = jnp.zeros((len(nodes),))
    zero_weights = jnp.zeros((len(edges),))

    # Two models differing only in the last constructor argument (1.0 vs 0.5;
    # presumably the inverse temperature -- confirm against IsingEBM).
    ebm_cold = IsingEBM(nodes, edges, zero_biases, zero_weights, jnp.array(1.0))
    ebm_hot = IsingEBM(nodes, edges, zero_biases, zero_weights, jnp.array(0.5))
    ebms = [ebm_cold, ebm_hot]
    programs = [IsingSamplingProgram(model, free_blocks, clamped_blocks=[]) for model in ebms]

    # Both chains start from the same empty block state.
    template = make_empty_block_state(free_blocks, ebm_cold.node_shape_dtypes)

    def run(key, init_states):
        return parallel_tempering(
            key,
            ebms,
            programs,
            init_states,
            clamp_state=[],
            n_rounds=2,
            gibbs_steps_per_round=1,
        )

    final_states, sampler_states, stats = jax.jit(run)(jax.random.key(0), [template, template])

    assert len(final_states) == 2
    assert len(sampler_states) == 2

    # One adjacent pair, so every statistic is a length-1 vector.
    for name in ("accepted", "attempted", "acceptance_rate"):
        assert stats[name].shape == (1,)
    assert stats["accepted"][0] == 1
    assert stats["attempted"][0] == 1
    assert stats["acceptance_rate"][0] == 1.0
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,239 @@ | ||
| """ | ||
| Parallel tempering utilities built on THRML's block Gibbs samplers. | ||
|
|
||
| This module orchestrates multiple tempered chains (one per beta/model), runs | ||
| blocked Gibbs steps on each, and proposes swaps between adjacent temperatures. | ||
| It does not alter core sampling code; it simply composes existing | ||
| `BlockSamplingProgram`s. | ||
| """ | ||
|
|
||
| from typing import Sequence | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jax import lax | ||
|
|
||
| from thrml.block_sampling import BlockSamplingProgram, sample_blocks | ||
| from thrml.conditional_samplers import AbstractConditionalSampler | ||
| from thrml.models.ebm import AbstractEBM | ||
|
|
||
|
|
||
def _init_sampler_states(program: BlockSamplingProgram):
    """Build the initial sampler-state tree for ``program``.

    Every ``AbstractConditionalSampler`` found in ``program.samplers`` is
    treated as a leaf and replaced by the value of its ``init()`` call.

    Args:
        program: The sampling program whose ``samplers`` are initialised.

    Returns:
        The result of mapping ``init()`` over ``program.samplers``.
    """

    def is_sampler(candidate):
        # Stop tree traversal at sampler objects instead of descending into them.
        return isinstance(candidate, AbstractConditionalSampler)

    def initial_state(sampler):
        return sampler.init()

    return jax.tree.map(initial_state, program.samplers, is_leaf=is_sampler)
|
|
||
|
|
||
def _gibbs_steps(
    key,
    program: BlockSamplingProgram,
    state_free: list,
    state_clamp: list,
    sampler_state: list,
    n_iters: int,
) -> tuple[list, list]:
    """Advance one chain by ``n_iters`` calls to ``sample_blocks``.

    Args:
        key: PRNG key; split into one subkey per sweep.
        program: Sampling program handed to ``sample_blocks``.
        state_free: Current free-block state of the chain.
        state_clamp: Clamped-block state, passed through unchanged.
        sampler_state: Current sampler state of the chain.
        n_iters: Number of sweeps. The loop is a plain Python loop, so this
            must be a concrete integer.

    Returns:
        ``(state_free, sampler_state)`` after the last sweep. When
        ``n_iters`` is 0 the inputs are returned as-is and ``key`` is unused.
    """
    current_free, current_sampler = state_free, sampler_state

    if n_iters != 0:
        sweep_keys = jax.random.split(key, n_iters)
        for sweep_key in sweep_keys:
            current_free, current_sampler = sample_blocks(
                sweep_key, current_free, state_clamp, program, current_sampler
            )

    return current_free, current_sampler
|
|
||
|
|
||
def _attempt_swap_pair(
    key,
    ebm_i: AbstractEBM,
    ebm_j: AbstractEBM,
    program_i: BlockSamplingProgram,
    program_j: BlockSamplingProgram,
    state_i: list,
    state_j: list,
    clamp_state: list,
):
    """Propose exchanging the states of two adjacent chains ``i`` and ``j``.

    The log acceptance ratio compares the energies of the current assignment
    with those of the exchanged assignment, each state evaluated under both
    models::

        log r = (E_i(x_i) + E_j(x_j)) - (E_i(x_j) + E_j(x_i))

    The exchange is taken with probability ``min(1, exp(log r))``.

    Args:
        key: PRNG key for the single uniform draw.
        ebm_i, ebm_j: Models of the two chains; only ``energy`` is called.
        program_i, program_j: Programs supplying ``gibbs_spec.blocks`` for the
            energy evaluations.
        state_i, state_j: Free-block states of the two chains.
        clamp_state: Clamped-block state appended to each free state before
            the energy call.

    Returns:
        ``(new_state_i, new_state_j, accepted)`` where ``accepted`` is an
        ``int32`` scalar equal to 1 if the states were exchanged and 0 if not.
    """
    blocks_i = program_i.gibbs_spec.blocks
    blocks_j = program_j.gibbs_spec.blocks
    full_i = state_i + clamp_state
    full_j = state_j + clamp_state

    energy_current = ebm_i.energy(full_i, blocks_i) + ebm_j.energy(full_j, blocks_j)
    energy_exchanged = ebm_i.energy(full_j, blocks_i) + ebm_j.energy(full_i, blocks_j)

    log_ratio = energy_current - energy_exchanged
    # Clamp at 0 before exponentiating so the probability never exceeds 1.
    accept_prob = jnp.exp(jnp.minimum(0.0, log_ratio))
    draw = jax.random.uniform(key)

    def exchanged(pair):
        first, second = pair
        return second, first, jnp.int32(1)

    def unchanged(pair):
        first, second = pair
        return first, second, jnp.int32(0)

    return lax.cond(draw < accept_prob, exchanged, unchanged, (state_i, state_j))
|
|
||
|
|
||
def _swap_pass(
    key,
    ebms: Sequence[AbstractEBM],
    programs: Sequence[BlockSamplingProgram],
    states: list[list],
    sampler_states: list[list],
    clamp_state: list,
    pair_indices: Sequence[int],
):
    """Perform one swap pass over a fixed set of adjacent pairs.

    For every ``p`` in ``pair_indices`` a swap between chains ``p`` and
    ``p + 1`` is proposed via ``_attempt_swap_pair``. When the proposal is
    accepted, both the chain states and the matching sampler states are
    exchanged; when it is rejected, both are left where they were.

    Args:
        key: PRNG key; split into one subkey per proposed pair.
        ebms: One model per chain.
        programs: One sampling program per chain.
        states: Per-chain free-block states (not mutated; a shallow copy is
            updated and returned).
        sampler_states: Per-chain sampler states (same copy semantics). The
            two sampler states of a proposed pair must share one pytree
            structure, because they are exchanged under ``lax.cond``.
        clamp_state: Clamped-block state shared by all chains.
        pair_indices: Lower indices of the adjacent pairs to propose.

    Returns:
        ``(new_states, new_sampler_states, accept_counts, attempt_counts)``.
        The two count lists have ``len(ebms) - 1`` entries. Attempted pairs
        have ``attempt_counts[p] == 1`` and ``accept_counts[p]`` set to the
        ``int32`` acceptance flag; all other entries are 0. If
        ``pair_indices`` is empty the inputs are returned unchanged.
    """
    n_pairs = len(ebms) - 1
    accept_counts = [0] * n_pairs
    attempt_counts = [0] * n_pairs

    if len(pair_indices) == 0:
        return states, sampler_states, accept_counts, attempt_counts

    keys = jax.random.split(key, len(pair_indices))
    new_states = list(states)
    new_sampler_states = list(sampler_states)

    for idx, pair in enumerate(pair_indices):
        i, j = pair, pair + 1
        attempt_counts[pair] = 1
        new_i, new_j, accepted = _attempt_swap_pair(
            keys[idx], ebms[i], ebms[j], programs[i], programs[j], new_states[i], new_states[j], clamp_state
        )
        # States already come back exchanged or untouched from _attempt_swap_pair.
        new_states[i], new_states[j] = new_i, new_j
        # Sampler states must follow the same decision. Exchanging them
        # unconditionally would detach them from their chain states whenever
        # the proposal is rejected, so gate the exchange on the acceptance flag.
        new_sampler_states[i], new_sampler_states[j] = lax.cond(
            accepted > 0,
            lambda pair_states: (pair_states[1], pair_states[0]),
            lambda pair_states: (pair_states[0], pair_states[1]),
            (new_sampler_states[i], new_sampler_states[j]),
        )
        accept_counts[pair] = accepted

    return new_states, new_sampler_states, accept_counts, attempt_counts
|
|
||
|
|
||
def parallel_tempering(
    key,
    ebms: Sequence[AbstractEBM],
    programs: Sequence[BlockSamplingProgram],
    init_states: Sequence[list],
    clamp_state: list,
    n_rounds: int,
    gibbs_steps_per_round: int,
    sampler_states: Sequence[list] | None = None,
):
    """Run parallel tempering across a sequence of tempered chains.

    Each round performs block Gibbs updates in every chain, then proposes
    swaps between adjacent chains. Rounds with an even index propose pairs
    ``(0, 1), (2, 3), ...``; rounds with an odd index propose
    ``(1, 2), (3, 4), ...``. All chains share the same block structure and
    clamped state layout.

    Args:
        key: PRNG key; a fresh subkey is derived for every round.
        ebms: One model per chain, used for the swap energies.
        programs: One sampling program per chain, in the same order as ``ebms``.
        init_states: Initial free-block state of every chain.
        clamp_state: Clamped-block state shared by all chains; a falsy value
            is treated as an empty list.
        n_rounds: Number of (Gibbs + swap) rounds. Values ``<= 0`` skip
            sampling and return the initial states.
        gibbs_steps_per_round: Number of Gibbs sweeps per chain per round.
        sampler_states: Optional per-chain sampler states; initialised from
            each program's samplers when omitted.

    Returns:
        ``(states, sampler_states, stats)``. ``stats`` is a dict with keys
        ``"accepted"`` and ``"attempted"`` (``int32`` vectors with one entry
        per adjacent pair) and ``"acceptance_rate"`` (their ratio, 0.0 for
        pairs that were never attempted).

    Raises:
        ValueError: If the per-chain sequences differ in length, or the
            programs disagree on the number of free or clamped blocks.
    """
    # NOTE(review): a collaborator suggested adding a higher-level wrapper that
    # samples from a single EBM given beta-type parameters and builds the
    # tempered models/programs internally. The author agreed and proposed it
    # as a follow-up PR once placement and API shape are settled.
    if not (len(ebms) == len(programs) == len(init_states)):
        raise ValueError("ebms, programs, and init_states must have the same length.")
    if sampler_states is not None and len(sampler_states) != len(programs):
        raise ValueError("sampler_states must match length of programs if provided.")

    n_chains = len(ebms)

    # Only block *counts* are compared. With zero chains there is nothing to
    # compare, so skip the check instead of failing on programs[0].
    if n_chains > 0:
        base_spec = programs[0].gibbs_spec
        base_free = len(base_spec.free_blocks)
        base_clamped = len(base_spec.clamped_blocks)
        for prog in programs[1:]:
            if len(prog.gibbs_spec.free_blocks) != base_free or len(prog.gibbs_spec.clamped_blocks) != base_clamped:
                raise ValueError("All programs must share the same block structure (free + clamped blocks).")

    clamp_state = clamp_state or []
    states = [list(s) for s in init_states]
    sampler_states = (
        [list(s) for s in sampler_states] if sampler_states is not None else [_init_sampler_states(p) for p in programs]
    )

    n_pairs = max(n_chains - 1, 0)
    accepted = jnp.zeros((n_pairs,), dtype=jnp.int32)
    attempted = jnp.zeros((n_pairs,), dtype=jnp.int32)

    # Precompute adjacent pair indices for even and odd rounds.
    even_pair_indices = list(range(0, n_pairs, 2))
    odd_pair_indices = list(range(1, n_pairs, 2))

    def make_swap_branch(pair_indices):
        """Return a ``lax.cond`` branch that runs one swap pass over ``pair_indices``."""

        def branch(args):
            states, sampler_states, accepted, attempted, swap_key = args
            states, sampler_states, acc_counts, att_counts = _swap_pass(
                swap_key,
                ebms,
                programs,
                states,
                sampler_states,
                clamp_state,
                pair_indices,
            )
            accepted = accepted + jnp.array(acc_counts, dtype=jnp.int32)
            attempted = attempted + jnp.array(att_counts, dtype=jnp.int32)
            return states, sampler_states, accepted, attempted

        return branch

    swap_even = make_swap_branch(even_pair_indices)
    swap_odd = make_swap_branch(odd_pair_indices)

    def one_round(carry, round_idx):
        key, states, sampler_states, accepted, attempted = carry

        # Keys for this round: one per chain plus a final one for the swap pass.
        key, key_round = jax.random.split(key)
        keys = jax.random.split(key_round, n_chains + 1)
        swap_key = keys[-1]

        # Gibbs updates per chain (number of chains is static).
        for i, program in enumerate(programs):
            states[i], sampler_states[i] = _gibbs_steps(
                keys[i],
                program,
                states[i],
                clamp_state,
                sampler_states[i],
                gibbs_steps_per_round,
            )

        # round_idx is traced inside scan, so the parity choice goes through lax.cond.
        parity = round_idx & 1
        states, sampler_states, accepted, attempted = lax.cond(
            parity == 0,
            swap_even,
            swap_odd,
            (states, sampler_states, accepted, attempted, swap_key),
        )

        return (key, states, sampler_states, accepted, attempted), None

    if n_rounds > 0:
        init_carry = (key, states, sampler_states, accepted, attempted)
        final_carry, _ = lax.scan(one_round, init_carry, jnp.arange(n_rounds))
        key, states, sampler_states, accepted, attempted = final_carry

    # Clamp the denominator so never-attempted pairs do not evaluate 0/0; the
    # jnp.where already selects 0.0 for them, but the unguarded division would
    # still create NaN intermediates.
    acceptance_rate = jnp.where(attempted > 0, accepted / jnp.maximum(attempted, 1), 0.0)
    stats = {
        "accepted": accepted,
        "attempted": attempted,
        "acceptance_rate": acceptance_rate,
    }

    return states, sampler_states, stats
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think parallel tempering is an interesting potential addition (I've looked at works such as https://arxiv.org/pdf/1905.02939, which I think could be really exciting); however, maybe we can think more about how best to integrate it — specifically, how best to work with parallel tempering within the graphical model framework. Granted, I don't think it will inherit from ConditionalSampler, but perhaps there is a different inheritance line to follow? What are your thoughts? Presumably these sorts of "second-order" samplers could have a well-designed pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this initial PR, I kept parallel tempering as a standalone utility that composes existing BlockSamplingPrograms without modifying the sampler hierarchy.
I agree that a more formal integration (e.g., defining a second-order sampler abstraction or an inheritance path distinct from ConditionalSampler) would make sense longer-term. Before restructuring, I’d love your thoughts on where this fits best in THRML’s sampler architecture.
Should parallel tempering live as:
Happy to iterate on a design that fits the broader framework.