Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/test_tempering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import jax
import jax.numpy as jnp

from thrml import Block, SpinNode, make_empty_block_state
from thrml.models import IsingEBM, IsingSamplingProgram
from thrml.tempering import parallel_tempering


def _tiny_ising():
    """Build a 2x2 periodic spin lattice split into two checkerboard blocks.

    Returns:
        A ``(nodes, edges, free_blocks)`` tuple where ``nodes`` lists the four
        spins in row-major order, ``edges`` holds the "right" and "down" link
        of every site (neighbor indices wrap modulo 2), and ``free_blocks``
        contains one ``Block`` per checkerboard color (even sites first).
    """
    side = 2
    lattice = [[SpinNode() for _ in range(side)] for _ in range(side)]
    coords = [(row, col) for row in range(side) for col in range(side)]

    nodes = [lattice[row][col] for row, col in coords]

    edges = []
    for row, col in coords:
        here = lattice[row][col]
        right = lattice[row][(col + 1) % side]
        down = lattice[(row + 1) % side][col]
        edges.extend([(here, right), (here, down)])

    # Checkerboard two-coloring: parity of (row + col) decides the block.
    even_nodes = [lattice[row][col] for row, col in coords if (row + col) % 2 == 0]
    odd_nodes = [lattice[row][col] for row, col in coords if (row + col) % 2 == 1]

    return nodes, edges, [Block(even_nodes), Block(odd_nodes)]


def test_parallel_tempering_smoke():
    """Smoke-test ``parallel_tempering`` under ``jax.jit`` on two zero-energy chains.

    All biases and weights are zero, so the single swap proposal between the
    two chains is expected to be accepted and the statistics for the one
    adjacent pair can be asserted exactly.
    """
    nodes, edges, free_blocks = _tiny_ising()
    zero_biases = jnp.zeros((len(nodes),))
    zero_weights = jnp.zeros((len(edges),))

    # Same graph and parameters for both chains; only the last constructor
    # argument differs (1.0 for the cold chain, 0.5 for the hot chain).
    ebm_cold = IsingEBM(nodes, edges, zero_biases, zero_weights, jnp.array(1.0))
    ebm_hot = IsingEBM(nodes, edges, zero_biases, zero_weights, jnp.array(0.5))
    chain_ebms = (ebm_cold, ebm_hot)
    programs = [IsingSamplingProgram(model, free_blocks, clamped_blocks=[]) for model in chain_ebms]

    template_state = make_empty_block_state(free_blocks, ebm_cold.node_shape_dtypes)
    init_states = [template_state, template_state]

    @jax.jit
    def run(rng, start_states):
        return parallel_tempering(
            rng,
            list(chain_ebms),
            programs,
            start_states,
            clamp_state=[],
            n_rounds=2,
            gibbs_steps_per_round=1,
        )

    final_states, sampler_states, stats = run(jax.random.key(0), init_states)

    assert len(final_states) == 2
    assert len(sampler_states) == 2

    # Two chains -> exactly one adjacent pair, so every statistic has shape (1,).
    for field in ("accepted", "attempted", "acceptance_rate"):
        assert stats[field].shape == (1,)

    # Only the even round proposes pair 0; the odd round has no pair to try.
    assert stats["accepted"][0] == 1
    assert stats["attempted"][0] == 1
    assert stats["acceptance_rate"][0] == 1.0
1 change: 1 addition & 0 deletions thrml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@
from .pgm import AbstractNode as AbstractNode
from .pgm import CategoricalNode as CategoricalNode
from .pgm import SpinNode as SpinNode
from .tempering import parallel_tempering as parallel_tempering

__version__ = importlib.metadata.version("thrml")
239 changes: 239 additions & 0 deletions thrml/tempering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
"""
Parallel tempering utilities built on THRML's block Gibbs samplers.

This module orchestrates multiple tempered chains (one per beta/model), runs
blocked Gibbs steps on each, and proposes swaps between adjacent temperatures.
It does not alter core sampling code; it simply composes existing
`BlockSamplingProgram`s.
"""

from typing import Sequence

import jax
import jax.numpy as jnp
from jax import lax

from thrml.block_sampling import BlockSamplingProgram, sample_blocks
from thrml.conditional_samplers import AbstractConditionalSampler
from thrml.models.ebm import AbstractEBM


def _init_sampler_states(program: BlockSamplingProgram):
    """Build the initial sampler-state pytree for ``program``.

    Every ``AbstractConditionalSampler`` found in ``program.samplers`` is
    treated as a leaf and replaced by the result of its ``init()`` call; the
    surrounding container structure is kept as-is.
    """

    def _is_sampler(candidate) -> bool:
        # Stop tree traversal at sampler objects instead of descending into them.
        return isinstance(candidate, AbstractConditionalSampler)

    def _fresh_state(sampler):
        return sampler.init()

    return jax.tree.map(_fresh_state, program.samplers, is_leaf=_is_sampler)


def _gibbs_steps(
    key,
    program: BlockSamplingProgram,
    state_free: list,
    state_clamp: list,
    sampler_state: list,
    n_iters: int,
) -> tuple[list, list]:
    """Advance one chain by ``n_iters`` calls to ``sample_blocks``.

    Args:
        key: PRNG key; split into one sub-key per sweep.
        program: Sampling program handed to ``sample_blocks``.
        state_free: Current free-block state of the chain.
        state_clamp: Clamped-block state, passed through unchanged.
        sampler_state: Current sampler state of the chain.
        n_iters: Number of sweeps. The loop is a plain Python loop, so the
            sweeps are unrolled when this runs under tracing.

    Returns:
        The updated ``(state_free, sampler_state)`` pair; the inputs are
        returned untouched when ``n_iters`` is zero.
    """
    if n_iters != 0:
        sweep_keys = jax.random.split(key, n_iters)
        for step in range(n_iters):
            state_free, sampler_state = sample_blocks(
                sweep_keys[step], state_free, state_clamp, program, sampler_state
            )
    return state_free, sampler_state


def _attempt_swap_pair(
    key,
    ebm_i: AbstractEBM,
    ebm_j: AbstractEBM,
    program_i: BlockSamplingProgram,
    program_j: BlockSamplingProgram,
    state_i: list,
    state_j: list,
    clamp_state: list,
):
    """Run one Metropolis exchange test between chains ``i`` and ``j``.

    The log acceptance ratio compares the summed energies of the current
    assignment with those of the exchanged assignment:

        log r = (E_i(x_i) + E_j(x_j)) - (E_i(x_j) + E_j(x_i))

    and the exchange is accepted with probability ``min(1, exp(log r))``.
    NOTE(review): this is the standard exchange rule if each chain targets a
    density proportional to ``exp(-E)`` -- confirm against ``AbstractEBM.energy``.

    Returns:
        ``(new_state_i, new_state_j, accepted)`` where the two states are
        exchanged when the proposal is accepted and ``accepted`` is an
        ``int32`` flag (1 accepted, 0 rejected).
    """

    def _energy(ebm, program, free_state):
        # Each model is evaluated on the free state followed by the clamped
        # state, using the block list of that model's own program.
        return ebm.energy(free_state + clamp_state, program.gibbs_spec.blocks)

    energy_i_own = _energy(ebm_i, program_i, state_i)
    energy_j_own = _energy(ebm_j, program_j, state_j)
    energy_i_other = _energy(ebm_i, program_i, state_j)
    energy_j_other = _energy(ebm_j, program_j, state_i)

    log_ratio = (energy_i_own + energy_j_own) - (energy_i_other + energy_j_other)
    # Clip at 0 before exponentiating so the probability never exceeds 1.
    accept_prob = jnp.exp(jnp.minimum(0.0, log_ratio))
    draw = jax.random.uniform(key)

    def _exchange(pair):
        first, second = pair
        return second, first, jnp.int32(1)

    def _keep(pair):
        first, second = pair
        return first, second, jnp.int32(0)

    return lax.cond(draw < accept_prob, _exchange, _keep, (state_i, state_j))


def _swap_pass(
    key,
    ebms: Sequence[AbstractEBM],
    programs: Sequence[BlockSamplingProgram],
    states: list[list],
    sampler_states: list[list],
    clamp_state: list,
    pair_indices: Sequence[int],
):
    """Perform one swap pass over a fixed set of adjacent pairs.

    Args:
        key: PRNG key; split into one sub-key per proposed pair.
        ebms: One energy model per chain.
        programs: One sampling program per chain (same order as ``ebms``).
        states: Current free-block state of every chain.
        sampler_states: Current sampler state of every chain.
        clamp_state: Clamped-block state shared by all chains.
        pair_indices: Left indices ``p`` of the pairs ``(p, p + 1)`` to try.

    Returns:
        ``(new_states, new_sampler_states, accept_counts, attempt_counts)``.
        The two count lists have one entry per adjacent pair
        (``len(ebms) - 1``); entries for pairs that were not proposed stay 0.
        The input lists are not mutated.
    """
    n_pairs = len(ebms) - 1
    accept_counts = [0] * n_pairs
    attempt_counts = [0] * n_pairs

    if len(pair_indices) == 0:
        return states, sampler_states, accept_counts, attempt_counts

    keys = jax.random.split(key, len(pair_indices))
    new_states = list(states)
    new_sampler_states = list(sampler_states)

    def _exchange(pair):
        first, second = pair
        return second, first

    def _keep(pair):
        return pair

    for idx, pair in enumerate(pair_indices):
        i, j = pair, pair + 1
        attempt_counts[pair] = 1
        # States come back already exchanged (or not) from _attempt_swap_pair.
        new_states[i], new_states[j], accepted = _attempt_swap_pair(
            keys[idx], ebms[i], ebms[j], programs[i], programs[j], new_states[i], new_states[j], clamp_state
        )
        # Sampler states must follow the same decision as the configurations:
        # exchange them only when the proposal was accepted. Exchanging them
        # unconditionally would detach a rejected chain from its sampler state.
        new_sampler_states[i], new_sampler_states[j] = lax.cond(
            accepted == 1,
            _exchange,
            _keep,
            (new_sampler_states[i], new_sampler_states[j]),
        )
        accept_counts[pair] = accepted

    return new_states, new_sampler_states, accept_counts, attempt_counts


def parallel_tempering(
    key,
    ebms: Sequence[AbstractEBM],
    programs: Sequence[BlockSamplingProgram],
    init_states: Sequence[list],
    clamp_state: list,
    n_rounds: int,
    gibbs_steps_per_round: int,
    sampler_states: Sequence[list] | None = None,
):
    """Run parallel tempering across a sequence of tempered chains.

    Each round performs block Gibbs updates in every chain, then proposes
    swaps between adjacent chains. Even rounds try the pairs starting at
    index 0, 2, ...; odd rounds try the pairs starting at index 1, 3, ....
    All chains share the same block structure and clamped state layout.

    Args:
        key: PRNG key driving both the Gibbs updates and the swap proposals.
        ebms: One energy model per chain, ordered so that neighbors in the
            sequence are the chains allowed to swap.
        programs: One sampling program per chain, same order as ``ebms``.
        init_states: Initial free-block state of every chain.
        clamp_state: Clamped-block state shared by all chains; a falsy value
            is treated as an empty list.
        n_rounds: Number of Gibbs-then-swap rounds. Used in Python control
            flow, so it must be a concrete integer.
        gibbs_steps_per_round: Gibbs sweeps per chain in every round.
        sampler_states: Optional initial sampler state per chain; when
            omitted each program's samplers are initialized via ``init()``.

    Returns:
        ``(states, sampler_states, stats)`` where ``stats`` is a dict with the
        per-pair arrays ``"accepted"``, ``"attempted"`` and
        ``"acceptance_rate"`` (0.0 for pairs that were never attempted).

    Raises:
        ValueError: If the per-chain sequences differ in length, if no chain
            is given, or if the programs do not have the same number of free
            and clamped blocks.
    """
    # NOTE(review): reviewers asked for a higher-level wrapper that takes a
    # single EBM plus a sequence of betas and builds the tempered models and
    # programs itself. This function is the low-level orchestration layer
    # over existing BlockSamplingPrograms; the wrapper is left for follow-up.
    if not (len(ebms) == len(programs) == len(init_states)):
        raise ValueError("ebms, programs, and init_states must have the same length.")
    if sampler_states is not None and len(sampler_states) != len(programs):
        raise ValueError("sampler_states must match length of programs if provided.")
    if len(programs) == 0:
        # Fail with a clear message instead of an IndexError on programs[0].
        raise ValueError("At least one chain (ebm, program, init_state) is required.")

    base_spec = programs[0].gibbs_spec
    base_free = len(base_spec.free_blocks)
    base_clamped = len(base_spec.clamped_blocks)
    for prog in programs[1:]:
        spec = prog.gibbs_spec
        if len(spec.free_blocks) != base_free or len(spec.clamped_blocks) != base_clamped:
            raise ValueError("All programs must share the same block structure (free + clamped blocks).")

    clamp_state = clamp_state or []
    states = [list(s) for s in init_states]
    if sampler_states is None:
        sampler_states = [_init_sampler_states(p) for p in programs]
    else:
        sampler_states = [list(s) for s in sampler_states]

    n_chains = len(ebms)
    n_pairs = n_chains - 1
    accepted = jnp.zeros((n_pairs,), dtype=jnp.int32)
    attempted = jnp.zeros((n_pairs,), dtype=jnp.int32)

    # Adjacent pairs are split by parity so that no chain takes part in two
    # swap proposals within the same round.
    even_pair_indices = list(range(0, n_pairs, 2))
    odd_pair_indices = list(range(1, n_pairs, 2))

    def _make_swap_branch(pair_indices):
        """Return a ``lax.cond`` branch that swaps over ``pair_indices``."""

        def branch(args):
            chain_states, chain_sampler_states, acc, att, swap_key = args
            chain_states, chain_sampler_states, acc_counts, att_counts = _swap_pass(
                swap_key,
                ebms,
                programs,
                chain_states,
                chain_sampler_states,
                clamp_state,
                pair_indices,
            )
            acc = acc + jnp.array(acc_counts, dtype=jnp.int32)
            att = att + jnp.array(att_counts, dtype=jnp.int32)
            return chain_states, chain_sampler_states, acc, att

        return branch

    do_even = _make_swap_branch(even_pair_indices)
    do_odd = _make_swap_branch(odd_pair_indices)

    def one_round(carry, round_idx):
        key, states, sampler_states, accepted, attempted = carry
        states = list(states)
        sampler_states = list(sampler_states)

        # One key per chain for the Gibbs updates plus one for the swap pass.
        key, key_round = jax.random.split(key)
        keys = jax.random.split(key_round, n_chains + 1)
        swap_key = keys[-1]

        # Gibbs updates per chain (the number of chains is static).
        for i in range(n_chains):
            states[i], sampler_states[i] = _gibbs_steps(
                keys[i],
                programs[i],
                states[i],
                clamp_state,
                sampler_states[i],
                gibbs_steps_per_round,
            )

        # round_idx is traced inside scan, so the parity choice goes through
        # lax.cond rather than a Python `if`.
        states, sampler_states, accepted, attempted = lax.cond(
            (round_idx & 1) == 0,
            do_even,
            do_odd,
            (states, sampler_states, accepted, attempted, swap_key),
        )

        return (key, states, sampler_states, accepted, attempted), None

    if n_rounds > 0:
        init_carry = (key, states, sampler_states, accepted, attempted)
        final_carry, _ = lax.scan(one_round, init_carry, jnp.arange(n_rounds))
        key, states, sampler_states, accepted, attempted = final_carry

    # Replace zero denominators before dividing so pairs that were never
    # attempted do not produce NaN/inf intermediates; they report a rate of 0.
    was_attempted = attempted > 0
    safe_attempted = jnp.where(was_attempted, attempted, 1)
    acceptance_rate = jnp.where(was_attempted, accepted / safe_attempted, 0.0)
    stats = {
        "accepted": accepted,
        "attempted": attempted,
        "acceptance_rate": acceptance_rate,
    }

    return states, sampler_states, stats