From e328bf479673a7e6b0d7e3ea2767bcf4a3f68b8c Mon Sep 17 00:00:00 2001 From: vis Date: Sat, 6 Dec 2025 10:02:18 +0530 Subject: [PATCH 1/3] parallel tempering for ebm samplers --- tests/test_tempering.py | 62 +++++++++++++ thrml/__init__.py | 1 + thrml/tempering.py | 188 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 tests/test_tempering.py create mode 100644 thrml/tempering.py diff --git a/tests/test_tempering.py b/tests/test_tempering.py new file mode 100644 index 0000000..b13e8c6 --- /dev/null +++ b/tests/test_tempering.py @@ -0,0 +1,62 @@ +import jax +import jax.numpy as jnp + +from thrml import Block, SpinNode, make_empty_block_state +from thrml.models import IsingEBM, IsingSamplingProgram +from thrml.tempering import parallel_tempering + + +def _tiny_ising(): + # 2x2 torus + grid = [[SpinNode() for _ in range(2)] for _ in range(2)] + nodes = [n for row in grid for n in row] + edges = [] + for i in range(2): + for j in range(2): + n = grid[i][j] + edges.append((n, grid[i][(j + 1) % 2])) # right + edges.append((n, grid[(i + 1) % 2][j])) # down + # Two-coloring + even_nodes = [grid[i][j] for i in range(2) for j in range(2) if (i + j) % 2 == 0] + odd_nodes = [grid[i][j] for i in range(2) for j in range(2) if (i + j) % 2 == 1] + free_blocks = [Block(even_nodes), Block(odd_nodes)] + return nodes, edges, free_blocks + + +def test_parallel_tempering_smoke(): + nodes, edges, free_blocks = _tiny_ising() + biases = jnp.zeros((len(nodes),)) + weights = jnp.zeros((len(edges),)) + + # Two temperatures; energies are zero so swaps should always accept + ebm_cold = IsingEBM(nodes, edges, biases, weights, jnp.array(1.0)) + ebm_hot = IsingEBM(nodes, edges, biases, weights, jnp.array(0.5)) + programs = [ + IsingSamplingProgram(ebm_cold, free_blocks, clamped_blocks=[]), + IsingSamplingProgram(ebm_hot, free_blocks, clamped_blocks=[]), + ] + + init_state = make_empty_block_state(free_blocks, ebm_cold.node_shape_dtypes) + init_states = [init_state, init_state] + + key = jax.random.key(0) + final_states, sampler_states, stats = parallel_tempering( + key, + [ebm_cold, ebm_hot], + programs, + init_states, + clamp_state=[], + n_rounds=2, + gibbs_steps_per_round=1, + ) + + assert len(final_states) == 2 + assert len(sampler_states) == 2 + + # One adjacent pair + assert stats["accepted"].shape == (1,) + assert stats["attempted"].shape == (1,) + assert stats["acceptance_rate"].shape == (1,) + assert stats["accepted"][0] == 1 + assert stats["attempted"][0] == 1 + assert stats["acceptance_rate"][0] == 1.0 diff --git a/thrml/__init__.py b/thrml/__init__.py index 8996d0b..91f5e30 100644 --- a/thrml/__init__.py +++ b/thrml/__init__.py @@ -29,5 +29,6 @@ from .pgm import AbstractNode as AbstractNode from .pgm import CategoricalNode as CategoricalNode from .pgm import SpinNode as SpinNode +from .tempering import parallel_tempering as parallel_tempering __version__ = importlib.metadata.version("thrml") diff --git a/thrml/tempering.py b/thrml/tempering.py new file mode 100644 index 0000000..a28b624 --- /dev/null +++ b/thrml/tempering.py @@ -0,0 +1,188 @@ +""" +Parallel tempering utilities built on THRML's block Gibbs samplers. + +This module orchestrates multiple tempered chains (one per beta/model), runs +blocked Gibbs steps on each, and proposes swaps between adjacent temperatures. +It does not alter core sampling code; it simply composes existing +`BlockSamplingProgram`s. +""" + +from typing import Sequence + +import jax +import jax.numpy as jnp + +from thrml.block_sampling import BlockSamplingProgram, sample_blocks +from thrml.conditional_samplers import AbstractConditionalSampler +from thrml.models.ebm import AbstractEBM + + +def _init_sampler_states(program: BlockSamplingProgram): + """Initialize sampler state list for a BlockSamplingProgram.""" + return jax.tree.map( + lambda x: x.init(), + program.samplers, + is_leaf=lambda a: isinstance(a, AbstractConditionalSampler), + ) + + +def _gibbs_steps( + key, + program: BlockSamplingProgram, + state_free: list, + state_clamp: list, + sampler_state: list, + n_iters: int, +) -> tuple[list, list]: + """Run n_iters block-Gibbs sweeps for a single chain.""" + if n_iters == 0: + return state_free, sampler_state + + keys = jax.random.split(key, n_iters) + for k in keys: + state_free, sampler_state = sample_blocks(k, state_free, state_clamp, program, sampler_state) + return state_free, sampler_state + + +def _attempt_swap_pair( + key, + ebm_i: AbstractEBM, + ebm_j: AbstractEBM, + program_i: BlockSamplingProgram, + program_j: BlockSamplingProgram, + state_i: list, + state_j: list, + clamp_state: list, +): + """ + Propose a swap between two adjacent temperature chains (i, j). + + Acceptance ratio uses energies of both states under both models: + log r = (E_i(x_i) + E_j(x_j)) - (E_i(x_j) + E_j(x_i)) + """ + blocks_i = program_i.gibbs_spec.blocks + blocks_j = program_j.gibbs_spec.blocks + + Ei_xi = ebm_i.energy(state_i + clamp_state, blocks_i) + Ej_xj = ebm_j.energy(state_j + clamp_state, blocks_j) + Ei_xj = ebm_i.energy(state_j + clamp_state, blocks_i) + Ej_xi = ebm_j.energy(state_i + clamp_state, blocks_j) + + log_r = (Ei_xi + Ej_xj) - (Ei_xj + Ej_xi) + accept_prob = jnp.exp(jnp.minimum(0.0, log_r)) + accept = jax.random.uniform(key) < accept_prob + + if bool(accept): + return state_j, state_i, True + return state_i, state_j, False + + +def _swap_pass( + key, + ebms: Sequence[AbstractEBM], + programs: Sequence[BlockSamplingProgram], + states: list[list], + sampler_states: list[list], + clamp_state: list, + offset: int, +): + """Perform one swap pass over adjacent pairs with a given offset (0: (0,1),(2,3)...; 1: (1,2),(3,4)...).""" + n_pairs = len(ebms) - 1 + accept_counts = [0] * n_pairs + attempt_counts = [0] * n_pairs + + pair_indices = list(range(offset, n_pairs, 2)) + if len(pair_indices) == 0: + return states, sampler_states, accept_counts, attempt_counts + + keys = jax.random.split(key, len(pair_indices)) + new_states = list(states) + new_sampler_states = list(sampler_states) + + for idx, pair in enumerate(pair_indices): + i, j = pair, pair + 1 + attempt_counts[pair] = 1 + new_i, new_j, accepted = _attempt_swap_pair( + keys[idx], ebms[i], ebms[j], programs[i], programs[j], new_states[i], new_states[j], clamp_state + ) + if accepted: + accept_counts[pair] = 1 + new_states[i], new_states[j] = new_i, new_j + new_sampler_states[i], new_sampler_states[j] = new_sampler_states[j], new_sampler_states[i] + + return new_states, new_sampler_states, accept_counts, attempt_counts + + +def parallel_tempering( + key, + ebms: Sequence[AbstractEBM], + programs: Sequence[BlockSamplingProgram], + init_states: Sequence[list], + clamp_state: list, + n_rounds: int, + gibbs_steps_per_round: int, + sampler_states: Sequence[list] | None = None, +): + """ + Run parallel tempering with alternating swap passes across adjacent temperatures. + + Args: + key: JAX PRNG key. + ebms: Sequence of models ordered from coldest (highest beta) to hottest (lowest beta). + programs: Matching BlockSamplingPrograms (one per model). All must share the same block layout. + init_states: Initial free-block states for each chain (one list per program). + clamp_state: State for clamped blocks (shared across chains). + n_rounds: Number of outer iterations; each round does Gibbs updates then a swap pass. + gibbs_steps_per_round: Number of block Gibbs sweeps per chain before proposing swaps. + sampler_states: Optional initial sampler states (one per chain). Defaults to sampler.init(). + + Returns: + final_states: list of free-block states for each chain. + final_sampler_states: list of sampler states for each chain. + stats: dict with 'accepted', 'attempted', and 'acceptance_rate' arrays for adjacent pairs. + """ + if not (len(ebms) == len(programs) == len(init_states)): + raise ValueError("ebms, programs, and init_states must have the same length.") + if sampler_states is not None and len(sampler_states) != len(programs): + raise ValueError("sampler_states must match length of programs if provided.") + + base_spec = programs[0].gibbs_spec + base_free = len(base_spec.free_blocks) + base_clamped = len(base_spec.clamped_blocks) + for prog in programs[1:]: + if len(prog.gibbs_spec.free_blocks) != base_free or len(prog.gibbs_spec.clamped_blocks) != base_clamped: + raise ValueError("All programs must share the same block structure (free + clamped blocks).") + + clamp_state = clamp_state or [] + states = [list(s) for s in init_states] + sampler_states = ( + [list(s) for s in sampler_states] if sampler_states is not None else [_init_sampler_states(p) for p in programs] + ) + + n_pairs = max(len(ebms) - 1, 0) + accepted = jnp.zeros((n_pairs,), dtype=jnp.int32) + attempted = jnp.zeros((n_pairs,), dtype=jnp.int32) + + for round_idx in range(n_rounds): + key, key_round = jax.random.split(key) + keys = jax.random.split(key_round, len(ebms) + 1) + swap_key = keys[-1] + + # Gibbs updates per chain + for i in range(len(ebms)): + states[i], sampler_states[i] = _gibbs_steps( + keys[i], programs[i], states[i], clamp_state, sampler_states[i], gibbs_steps_per_round + ) + + # Swap pass (alternate even/odd pairing each round) + offset = round_idx % 2 + states, sampler_states, acc_counts, att_counts = _swap_pass( + swap_key, ebms, programs, states, sampler_states, clamp_state, offset + ) + accepted = accepted + jnp.array(acc_counts, dtype=jnp.int32) + attempted = attempted + jnp.array(att_counts, dtype=jnp.int32) + + acceptance_rate = jnp.where(attempted > 0, accepted / attempted, 0.0) + stats = {"accepted": accepted, "attempted": attempted, "acceptance_rate": acceptance_rate} + + return states, sampler_states, stats From 54441769fa4335886fb9eceec1ed80bb0c730bf0 Mon Sep 17 00:00:00 2001 From: vis Date: Wed, 10 Dec 2025 08:19:00 +0530 Subject: [PATCH 2/3] JIT-safe parallel tempering and smoke test --- tests/test_tempering.py | 22 ++++++++++-------- thrml/tempering.py | 49 ++++++++++++++++++++--------------------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/tests/test_tempering.py b/tests/test_tempering.py index b13e8c6..d648c55 100644 --- a/tests/test_tempering.py +++ b/tests/test_tempering.py @@ -39,16 +39,20 @@ def test_parallel_tempering_smoke(): init_state = make_empty_block_state(free_blocks, ebm_cold.node_shape_dtypes) init_states = [init_state, init_state] + @jax.jit + def run(key, init_states): + return parallel_tempering( + key, + [ebm_cold, ebm_hot], + programs, + init_states, + clamp_state=[], + n_rounds=2, + gibbs_steps_per_round=1, + ) + key = jax.random.key(0) - final_states, sampler_states, stats = parallel_tempering( - key, - [ebm_cold, ebm_hot], - programs, - init_states, - clamp_state=[], - n_rounds=2, - gibbs_steps_per_round=1, - ) + final_states, sampler_states, stats = run(key, init_states) assert len(final_states) == 2 assert len(sampler_states) == 2 diff --git a/thrml/tempering.py b/thrml/tempering.py index a28b624..abb0bec 100644 --- a/thrml/tempering.py +++ b/thrml/tempering.py @@ -11,6 +11,7 @@ import jax import jax.numpy as jnp +from jax import lax from thrml.block_sampling import BlockSamplingProgram, sample_blocks from thrml.conditional_samplers import AbstractConditionalSampler @@ -70,11 +71,19 @@ def _attempt_swap_pair( log_r = (Ei_xi + Ej_xj) - (Ei_xj + Ej_xi) accept_prob = jnp.exp(jnp.minimum(0.0, log_r)) - accept = jax.random.uniform(key) < accept_prob + u = jax.random.uniform(key) - if bool(accept): - return state_j, state_i, True - return state_i, state_j, False + def do_swap(states): + s_i, s_j = states + # swap and mark accepted + return s_j, s_i, jnp.int32(1) + + def no_swap(states): + s_i, s_j = states + # keep as-is and mark rejected + return s_i, s_j, jnp.int32(0) + + return lax.cond(u < accept_prob, do_swap, no_swap, (state_i, state_j)) def _swap_pass( @@ -105,10 +114,12 @@ def _swap_pass( new_i, new_j, accepted = _attempt_swap_pair( keys[idx], ebms[i], ebms[j], programs[i], programs[j], new_states[i], new_states[j], clamp_state ) - if accepted: - accept_counts[pair] = 1 - new_states[i], new_states[j] = new_i, new_j - new_sampler_states[i], new_sampler_states[j] = new_sampler_states[j], new_sampler_states[i] + + # states already come swapped or not from _attempt_swap_pair + new_states[i], new_states[j] = new_i, new_j + # sampler states follow the same swap pattern + new_sampler_states[i], new_sampler_states[j] = new_sampler_states[j], new_sampler_states[i] + accept_counts[pair] = accepted return new_states, new_sampler_states, accept_counts, attempt_counts @@ -123,23 +134,11 @@ def parallel_tempering( gibbs_steps_per_round: int, sampler_states: Sequence[list] | None = None, ): - """ - Run parallel tempering with alternating swap passes across adjacent temperatures. - - Args: - key: JAX PRNG key. - ebms: Sequence of models ordered from coldest (highest beta) to hottest (lowest beta). - programs: Matching BlockSamplingPrograms (one per model). All must share the same block layout. - init_states: Initial free-block states for each chain (one list per program). - clamp_state: State for clamped blocks (shared across chains). - n_rounds: Number of outer iterations; each round does Gibbs updates then a swap pass. - gibbs_steps_per_round: Number of block Gibbs sweeps per chain before proposing swaps. - sampler_states: Optional initial sampler states (one per chain). Defaults to sampler.init(). - - Returns: - final_states: list of free-block states for each chain. - final_sampler_states: list of sampler states for each chain. - stats: dict with 'accepted', 'attempted', and 'acceptance_rate' arrays for adjacent pairs. + """Run parallel tempering across a sequence of tempered chains. + + Each round performs block Gibbs updates in every chain, then proposes + swaps between adjacent temperatures. All chains share the same + block structure and clamped state layout. """ if not (len(ebms) == len(programs) == len(init_states)): raise ValueError("ebms, programs, and init_states must have the same length.") From 7d80bc3dd72fe358cd9ccaf5926ca7f21cf1f450 Mon Sep 17 00:00:00 2001 From: vis Date: Thu, 11 Dec 2025 08:35:43 +0530 Subject: [PATCH 3/3] refactor parallel_tempering using lax.scan --- thrml/tempering.py | 80 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 14 deletions(-) diff --git a/thrml/tempering.py b/thrml/tempering.py index abb0bec..12e3fe4 100644 --- a/thrml/tempering.py +++ b/thrml/tempering.py @@ -93,14 +93,13 @@ def _swap_pass( states: list[list], sampler_states: list[list], clamp_state: list, - offset: int, + pair_indices: Sequence[int], ): - """Perform one swap pass over adjacent pairs with a given offset (0: (0,1),(2,3)...; 1: (1,2),(3,4)...).""" + """Perform one swap pass over a fixed set of adjacent pairs.""" n_pairs = len(ebms) - 1 accept_counts = [0] * n_pairs attempt_counts = [0] * n_pairs - pair_indices = list(range(offset, n_pairs, 2)) if len(pair_indices) == 0: return states, sampler_states, accept_counts, attempt_counts @@ -114,7 +113,6 @@ def _swap_pass( new_i, new_j, accepted = _attempt_swap_pair( keys[idx], ebms[i], ebms[j], programs[i], programs[j], new_states[i], new_states[j], clamp_state ) - # states already come swapped or not from _attempt_swap_pair new_states[i], new_states[j] = new_i, new_j # sampler states follow the same swap pattern @@ -162,26 +160,80 @@ def parallel_tempering( accepted = jnp.zeros((n_pairs,), dtype=jnp.int32) attempted = jnp.zeros((n_pairs,), dtype=jnp.int32) - for round_idx in range(n_rounds): + # Precompute adjacent pair indices for even and odd rounds. + even_pair_indices = list(range(0, n_pairs, 2)) + odd_pair_indices = list(range(1, n_pairs, 2)) + + def one_round(carry, round_idx): + key, states, sampler_states, accepted, attempted = carry + + # Keys for this round key, key_round = jax.random.split(key) keys = jax.random.split(key_round, len(ebms) + 1) swap_key = keys[-1] - # Gibbs updates per chain + # Gibbs updates per chain (number of chains is static) for i in range(len(ebms)): states[i], sampler_states[i] = _gibbs_steps( - keys[i], programs[i], states[i], clamp_state, sampler_states[i], gibbs_steps_per_round + keys[i], + programs[i], + states[i], + clamp_state, + sampler_states[i], + gibbs_steps_per_round, ) - # Swap pass (alternate even/odd pairing each round) - offset = round_idx % 2 - states, sampler_states, acc_counts, att_counts = _swap_pass( - swap_key, ebms, programs, states, sampler_states, clamp_state, offset + def do_even(args): + states, sampler_states, accepted, attempted, swap_key = args + states, sampler_states, acc_counts, att_counts = _swap_pass( + swap_key, + ebms, + programs, + states, + sampler_states, + clamp_state, + even_pair_indices, + ) + accepted = accepted + jnp.array(acc_counts, dtype=jnp.int32) + attempted = attempted + jnp.array(att_counts, dtype=jnp.int32) + return states, sampler_states, accepted, attempted + + def do_odd(args): + states, sampler_states, accepted, attempted, swap_key = args + states, sampler_states, acc_counts, att_counts = _swap_pass( + swap_key, + ebms, + programs, + states, + sampler_states, + clamp_state, + odd_pair_indices, + ) + accepted = accepted + jnp.array(acc_counts, dtype=jnp.int32) + attempted = attempted + jnp.array(att_counts, dtype=jnp.int32) + return states, sampler_states, accepted, attempted + + parity = round_idx & 1 + states, sampler_states, accepted, attempted = lax.cond( + parity == 0, + do_even, + do_odd, + (states, sampler_states, accepted, attempted, swap_key), ) - accepted = accepted + jnp.array(acc_counts, dtype=jnp.int32) - attempted = attempted + jnp.array(att_counts, dtype=jnp.int32) + + new_carry = (key, states, sampler_states, accepted, attempted) + return new_carry, None + + if n_rounds > 0: + init_carry = (key, states, sampler_states, accepted, attempted) + final_carry, _ = lax.scan(one_round, init_carry, jnp.arange(n_rounds)) + key, states, sampler_states, accepted, attempted = final_carry acceptance_rate = jnp.where(attempted > 0, accepted / attempted, 0.0) - stats = {"accepted": accepted, "attempted": attempted, "acceptance_rate": acceptance_rate} + stats = { + "accepted": accepted, + "attempted": attempted, + "acceptance_rate": acceptance_rate, + } return states, sampler_states, stats