Skip to content

Conversation

@visvig
Copy link

@visvig visvig commented Dec 6, 2025

Parallel tempering support for EBMs using multiple tempered block-Gibbs chains + swap proposals

  • Uses core sampler code as-is
  • Includes a basic test

return new_states, new_sampler_states, accept_counts, attempt_counts


def parallel_tempering(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think parallel tempering is an interesting potential to add (I've looked at works such as https://arxiv.org/pdf/1905.02939 which I think could be really exciting), however, maybe we can think more how to best integrate it. Specifically, how to best work with parallel tempering within the graphical model framework. Granted I don't think it will inherit from conditionalsampler but perhaps there is a different inheritance line to follow down? What are your thoughts? Presumably these sort of "second order" samplers could have a well designed pattern.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this initial PR, I kept parallel tempering as a standalone utility that composes existing BlockSamplingPrograms without modifying the sampler hierarchy.

I agree that a more formal integration (e.g., defining a second-order sampler abstraction or an inheritance path distinct from ConditionalSampler) would make sense longer-term. Before restructuring, I’d love your thoughts on where this fits best in THRML’s sampler architecture.

Should parallel tempering live as:

  • a separate sampler type (similar to MCMC wrappers), or
  • an orchestration layer around existing samplers?

Happy to iterate on a design that fits the broader framework.


def parallel_tempering(
key,
ebms: Sequence[AbstractEBM],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like we should have a high level wrapper that would just sample from the EBM and accept some beta type parameters

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. The current function requires EBMs + BlockSamplingPrograms explicitly, but a high-level wrapper that accepts an EBM and a sequence of betas, constructs the tempered models/programs internally, and exposes a simple .sample() API would be much cleaner.

Once we align on placement and expected API shape, I can add it as a follow-up PR.

swap_key = keys[-1]

# Gibbs updates per chain
for i in range(len(ebms)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this double for loop seems painful for compilation time, I bet at least one of these could be converted to a scan

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, the nested loop is not ideal for compilation time indeed. Let me figure out how to JAXify.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants