-
Notifications
You must be signed in to change notification settings - Fork 118
Parallel tempering for ebm samplers #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| return new_states, new_sampler_states, accept_counts, attempt_counts | ||
|
|
||
|
|
||
| def parallel_tempering( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think parallel tempering is an interesting potential to add (I've looked at works such as https://arxiv.org/pdf/1905.02939 which I think could be really exciting), however, maybe we can think more how to best integrate it. Specifically, how to best work with parallel tempering within the graphical model framework. Granted I don't think it will inherit from conditionalsampler but perhaps there is a different inheritance line to follow down? What are your thoughts? Presumably these sort of "second order" samplers could have a well designed pattern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this initial PR, I kept parallel tempering as a standalone utility that composes existing BlockSamplingPrograms without modifying the sampler hierarchy.
I agree that a more formal integration (e.g., defining a second-order sampler abstraction or an inheritance path distinct from ConditionalSampler) would make sense longer-term. Before restructuring, I’d love your thoughts on where this fits best in THRML’s sampler architecture.
Should parallel tempering live as:
- a separate sampler type (similar to MCMC wrappers), or
- an orchestration layer around existing samplers?
Happy to iterate on a design that fits the broader framework.
|
|
||
| def parallel_tempering( | ||
| key, | ||
| ebms: Sequence[AbstractEBM], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feels like we should have a high level wrapper that would just sample from the EBM and accept some beta type parameters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense. The current function requires EBMs + BlockSamplingPrograms explicitly, but a high-level wrapper that accepts an EBM and a sequence of betas, constructs the tempered models/programs internally, and exposes a simple .sample() API would be much cleaner.
Once we align on placement and expected API shape, I can add it as a follow-up PR.
| swap_key = keys[-1] | ||
|
|
||
| # Gibbs updates per chain | ||
| for i in range(len(ebms)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this double for loop seems painful for compilation time, I bet at least one of these could be converted to a scan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, the nested loop is not ideal for compilation time indeed. Let me figure out how to JAXify.
Parallel tempering support for EBMs using multiple tempered block-Gibbs chains + swap proposals