Skip to content

Commit 077457f

Browse files
committed
Implement bundle_samples for ParamsWithStats -> MCMCChains
1 parent 2bf5b18 commit 077457f

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,44 @@ function AbstractMCMC.to_samples(
140140
end
141141
end
142142

143+
function AbstractMCMC.bundle_samples(
144+
ts::Vector{<:DynamicPPL.ParamsWithStats},
145+
model::DynamicPPL.Model,
146+
spl::AbstractMCMC.AbstractSampler,
147+
state,
148+
chain_type::Type{MCMCChains.Chains};
149+
save_state=false,
150+
stats=missing,
151+
sort_chain=false,
152+
discard_initial=0,
153+
thinning=1,
154+
kwargs...,
155+
)
156+
# Construct the 'bare' chain first
157+
bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1))
158+
159+
# Add additional MCMC-specific info
160+
info = bare_chain.info
161+
if save_state
162+
info = merge(info, (model=model, sampler=spl, samplerstate=state))
163+
end
164+
if !ismissing(stats)
165+
info = merge(info, (start_time=stats.start, stop_time=stats.stop))
166+
end
167+
168+
# Reconstruct the chain with the extra information
169+
# Yeah, this is quite ugly. Blame MCMCChains.
170+
chain = MCMCChains.Chains(
171+
bare_chain.value.data,
172+
names(bare_chain),
173+
bare_chain.name_map;
174+
info=info,
175+
start=discard_initial + 1,
176+
thin=thinning,
177+
)
178+
return sort_chain ? sort(chain) : chain
179+
end
180+
143181
"""
144182
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
145183

0 commit comments

Comments
 (0)