From a494d006a8a7465d1621f210ce9bcc04b0941d45 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:09:01 +0000 Subject: [PATCH 1/9] Make InitContext work with OnlyAccsVarInfo --- src/fasteval.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index c91254d43..dcdbaa608 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -81,6 +81,7 @@ using DynamicPPL: getlogprior_internal, leafcontext using ADTypes: ADTypes +using BangBang: BangBang using Bijectors: with_logabsdet_jacobian using AbstractPPL: AbstractPPL, VarName using Distributions: Distribution @@ -108,6 +109,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi function DynamicPPL.get_param_eltype( ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model ) @@ -117,14 +121,11 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) + elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} + eltype = DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + @info "Inferring parameter eltype as $eltype from InitContext" + return eltype else - # TODO(penelopeysm): In principle this can be done with InitContext{InitWithParams}. - # See also `src/simple_varinfo.jl` where `infer_nested_eltype` is used to try to - # figure out the parameter type from a NamedTuple or Dict. The benefit of - # implementing this for InitContext is that we could then use OnlyAccsVarInfo with - # it, which means fast evaluation with NamedTuple or Dict parameters! And I believe - # that Mooncake / Enzyme should be able to differentiate through that too and - # provide a NamedTuple of gradients (although I haven't tested this yet). error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) From 35432c07a5b06291f644fba14143b5fb4e2889ca Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:26:12 +0000 Subject: [PATCH 2/9] Do not convert NamedTuple to Dict --- src/contexts/init.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..efc6f1087 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -102,7 +102,7 @@ struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitS function InitFromParams( params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return InitFromParams(to_varname_dict(params), fallback) + return new{typeof(params),typeof(fallback)}(params, fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) From f330d59e07fe48992d19bba2c711f278dfdb66a5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 7 Nov 2025 17:34:18 +0000 Subject: [PATCH 3/9] remove logging --- src/fasteval.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index dcdbaa608..3b1ae2550 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -122,9 +122,7 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - eltype = DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) - @info "Inferring parameter eltype as $eltype from InitContext" - return eltype + return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) else error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", From 7deaaabae77c4f25ea8f7b423390446764072ec6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 17:24:35 +0000 Subject: [PATCH 4/9] Enable InitFromPrior and InitFromUniform too --- src/fasteval.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/fasteval.jl b/src/fasteval.jl index 3b1ae2550..2e1bccdc6 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,10 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + InitContext, + InitFromParams, + InitFromPrior, + InitFromUniform, LogJacobianAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -123,6 +127,9 @@ function DynamicPPL.get_param_eltype( return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} + # No need to enforce any particular eltype here, since new parameters are sampled + return Any else error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", From 699aa23b31d5f921cc5669bdc96faa3722a1f42d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 18:27:36 +0000 Subject: [PATCH 5/9] Fix `infer_nested_eltype` invocation --- src/fasteval.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index 2e1bccdc6..fbc6a61ce 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -126,7 +126,7 @@ function DynamicPPL.get_param_eltype( if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - return DynamicPPL.infer_nested_eltype(leaf_ctx.strategy.params) + return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} # No need to enforce any particular eltype here, since new parameters are sampled return Any From c188b7e43e45f2162259f7ab446338cfaaca5df6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 18:23:40 +0000 Subject: [PATCH 6/9] Use OnlyAccsVarInfo for many re-evaluation functions --- ext/DynamicPPLMCMCChainsExt.jl | 120 +++++++++++---------------------- src/chains.jl | 26 ------- src/model.jl | 2 +- 3 files changed, 41 insertions(+), 107 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d8c343917..5e8b349d8 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -3,39 +3,17 @@ module DynamicPPLMCMCChainsExt using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC using MCMCChains: MCMCChains -_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names - -function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) - return _has_varname_to_symbol(chain.info) -end - -function _check_varname_indexing(c::MCMCChains.Chains) - return DynamicPPL.supports_varname_indexing(c) || - error("This `Chains` object does not support indexing using `VarName`s.") -end - -function DynamicPPL.getindex_varname( +function getindex_varname( c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx ) - _check_varname_indexing(c) return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx] end -function DynamicPPL.varnames(c::MCMCChains.Chains) - _check_varname_indexing(c) +function get_varnames(c::MCMCChains.Chains) + haskey(c.info, :varname_to_symbol) || + error("This `Chains` object does not support indexing using `VarName`s.") return keys(c.info.varname_to_symbol) end -function chain_sample_to_varname_dict( - c::MCMCChains.Chains{Tval}, sample_idx, chain_idx -) where {Tval} - _check_varname_indexing(c) - d = Dict{DynamicPPL.VarName,Tval}() - for vn in DynamicPPL.varnames(c) - d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) - end - return d -end - """ AbstractMCMC.from_samples( ::Type{MCMCChains.Chains}, @@ -118,8 +96,8 @@ function AbstractMCMC.to_samples( # Get parameters params_matrix = map(idxs) do (sample_idx, chain_idx) d = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}() - for vn in DynamicPPL.varnames(chain) - d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx) + for vn in get_varnames(chain) + d[vn] = getindex_varname(chain, sample_idx, vn, chain_idx) end d end @@ -209,22 +187,17 @@ function DynamicPPL.predict( ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - # Set up a VarInfo with the right accumulators - varinfo = DynamicPPL.setaccs!!( - DynamicPPL.VarInfo(), - ( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - ), - ) - _, varinfo = DynamicPPL.init!!(model, varinfo) - varinfo = DynamicPPL.typed_varinfo(varinfo) - params_and_stats = AbstractMCMC.to_samples( DynamicPPL.ParamsWithStats, parameter_only_chain ) predictions = map(params_and_stats) do ps + varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo( + DynamicPPL.AccumulatorTuple(( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), + )), + ) _, varinfo = DynamicPPL.init!!( rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) ) @@ -311,16 +284,11 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - varinfo = DynamicPPL.VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) return map(params_with_stats) do ps + varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(())) first( - DynamicPPL.init!!( - model, - varinfo, - DynamicPPL.InitFromParams(ps.params, DynamicPPL.InitFromPrior()), - ), + DynamicPPL.init!!(model, varinfo, DynamicPPL.InitFromParams(ps.params, nothing)) ) end end @@ -415,21 +383,14 @@ function DynamicPPL.pointwise_logdensities( ::Type{Tout}=MCMCChains.Chains, ::Val{whichlogprob}=Val(:both), ) where {whichlogprob,Tout} - vi = DynamicPPL.VarInfo(model) acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) - vi = DynamicPPL.setaccs!!(vi, (acc,)) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - pointwise_logps = map(iters) do (sample_idx, chain_idx) - # Extract values from the chain - values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) - # Re-evaluate the model - _, vi = DynamicPPL.init!!( - model, - vi, - DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), - ) + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + pointwise_logps = map(params_with_stats) do ps + accs = DynamicPPL.AccumulatorTuple((acc,)) + vi = DynamicPPL.Experimental.OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) DynamicPPL.getacc(vi, Val(accname)).logps end @@ -519,14 +480,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + vi = DynamicPPL.Experimental.OnlyAccsVarInfo( + DynamicPPL.AccumulatorTuple(( + DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator() + )), ) - DynamicPPL.logjoint(model, argvals_dict) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) + DynamicPPL.getlogjoint(vi) end end @@ -559,14 +521,13 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + vi = DynamicPPL.Experimental.OnlyAccsVarInfo( + DynamicPPL.AccumulatorTuple((DynamicPPL.LogLikelihoodAccumulator())) ) - DynamicPPL.loglikelihood(model, argvals_dict) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) + DynamicPPL.getloglikelihood(vi) end end @@ -600,14 +561,13 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - var_info = DynamicPPL.VarInfo(model) # extract variables info from the model - map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}( - vn_parent => DynamicPPL.values_from_chain( - var_info, vn_parent, chain, chain_idx, iteration_idx - ) for vn_parent in keys(var_info) + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + vi = DynamicPPL.Experimental.OnlyAccsVarInfo( + DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator())) ) - DynamicPPL.logprior(model, argvals_dict) + _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) + DynamicPPL.getlogprior(vi) end end diff --git a/src/chains.jl b/src/chains.jl index 2b5976b9b..d47fb901a 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -1,29 +1,3 @@ -""" - supports_varname_indexing(chain::AbstractChains) - -Return `true` if `chain` supports indexing using `VarName` in place of the -variable name index. -""" -supports_varname_indexing(::AbstractChains) = false - -""" - getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) - -Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function getindex_varname end - -""" - varnames(chains::AbstractChains) - -Return an iterator over the varnames present in `chains`. - -Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). -""" -function varnames end - """ ParamsWithStats diff --git a/src/model.jl b/src/model.jl index 6ca06aea6..718c56372 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1148,7 +1148,7 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0)) ``` """ function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}}) - vi = DynamicPPL.setaccs!!(VarInfo(), ()) + vi = DynamicPPL.Experimental.OnlyAccsVarInfo(AccumulatorTuple()) # Note: we can't use `fix(model, parameters)` because # https://github.com/TuringLang/DynamicPPL.jl/issues/1097 # Use `nothing` as the fallback to ensure that any missing parameters cause an error From 5a976fd56ef1b8436a1a4f803be8569d81889a12 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 18:53:42 +0000 Subject: [PATCH 7/9] Generalise eltype inference --- src/fasteval.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/fasteval.jl b/src/fasteval.jl index fbc6a61ce..5b95c8376 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -60,6 +60,7 @@ using DynamicPPL: AbstractContext, AbstractVarInfo, AccumulatorTuple, + DynamicPPL, InitContext, InitFromParams, InitFromPrior, @@ -125,17 +126,20 @@ function DynamicPPL.get_param_eltype( leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) - elseif leaf_ctx isa InitContext{<:Any,<:InitFromParams} - return DynamicPPL.infer_nested_eltype(typeof(leaf_ctx.strategy.params)) - elseif leaf_ctx isa InitContext{<:Any,<:Union{InitFromPrior,InitFromUniform}} - # No need to enforce any particular eltype here, since new parameters are sampled - return Any + elseif leaf_ctx isa InitContext + return _get_strategy_eltype(leaf_ctx.strategy) else error( "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", ) end end +_get_strategy_eltype(s::InitFromParams) = DynamicPPL.infer_nested_eltype(typeof(s.params)) +# No need to enforce any particular eltype here, since new parameters are sampled +_get_strategy_eltype(::InitFromPrior) = Any +_get_strategy_eltype(::InitFromUniform) = Any +# Default fallback +_get_strategy_eltype(::DynamicPPL.AbstractInitStrategy) = Any """ RangeAndLinked From 8902e6afc468903fc306f65cf8ed46de7b2c3e3a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 22:14:47 +0000 Subject: [PATCH 8/9] Make it more elegant --- ext/DynamicPPLMCMCChainsExt.jl | 135 +++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 59 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 5e8b349d8..89fb511c1 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMCMCChainsExt -using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC +using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC, Random using MCMCChains: MCMCChains function getindex_varname( @@ -118,6 +118,47 @@ function AbstractMCMC.to_samples( end end +""" + reevaluate_with( + rng::AbstractRNG, + model::Model, + chain::MCMCChains.Chains; + fallback=nothing, + ) + +Re-evaluate `model` for each sample in `chain`, returning an matrix of (retval, varinfo) +tuples. + +This loops over all entries in the chain and uses `DynamicPPL.InitFromParams` as the +initialisation strategy when re-evaluating the model. For many usecases the fallback should +not be provided (as we expect the chain to contain all necessary variables); but for +`predict` this has to be `InitFromPrior()` to allow sampling new variables (i.e. generating +the posterior predictions). +""" +function reevaluate_with_chain( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) + return map(params_with_stats) do ps + varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs)) + DynamicPPL.init!!( + rng, model, varinfo, DynamicPPL.InitFromParams(ps.params, fallback) + ) + end +end +function reevaluate_with_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + accs::NTuple{N,DynamicPPL.AbstractAccumulator}, + fallback::Union{DynamicPPL.AbstractInitStrategy,Nothing}=nothing, +) where {N} + return reevaluate_with_chain(Random.default_rng(), model, chain, accs, fallback) +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -186,25 +227,18 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - - params_and_stats = AbstractMCMC.to_samples( - DynamicPPL.ParamsWithStats, parameter_only_chain + accs = DynamicPPL.AccumulatorTuple( + DynamicPPL.LogPriorAccumulator(), + DynamicPPL.LogLikelihoodAccumulator(), + DynamicPPL.ValuesAsInModelAccumulator(false), + ) + predictions = map( + DynamicPPL.ParamsWithStats ∘ last, + reevaluate_with_chain( + rng, model, parameter_only_chain, accs, DynamicPPL.InitFromPrior() + ), ) - predictions = map(params_and_stats) do ps - varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo( - DynamicPPL.AccumulatorTuple(( - DynamicPPL.LogPriorAccumulator(), - DynamicPPL.LogLikelihoodAccumulator(), - DynamicPPL.ValuesAsInModelAccumulator(false), - )), - ) - _, varinfo = DynamicPPL.init!!( - rng, model, varinfo, DynamicPPL.InitFromParams(ps.params) - ) - DynamicPPL.ParamsWithStats(varinfo) - end chain_result = AbstractMCMC.from_samples(MCMCChains.Chains, predictions) - parameter_names = if include_all MCMCChains.names(chain_result, :parameters) else @@ -284,13 +318,7 @@ julia> returned(model, chain) """ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - varinfo = DynamicPPL.Experimental.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(())) - first( - DynamicPPL.init!!(model, varinfo, DynamicPPL.InitFromParams(ps.params, nothing)) - ) - end + return map(first, reevaluate_with_chain(model, chain, (), nothing)) end """ @@ -386,14 +414,10 @@ function DynamicPPL.pointwise_logdensities( acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() accname = DynamicPPL.accumulator_name(acc) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - pointwise_logps = map(params_with_stats) do ps - accs = DynamicPPL.AccumulatorTuple((acc,)) - vi = DynamicPPL.Experimental.OnlyAccsVarInfo(accs) - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) - DynamicPPL.getacc(vi, Val(accname)).logps - end - + pointwise_logps = + map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) + DynamicPPL.getacc(vi, Val(accname)).logps + end # pointwise_logps is a matrix of OrderedDicts all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() for d in pointwise_logps @@ -480,16 +504,15 @@ julia> logjoint(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( - DynamicPPL.AccumulatorTuple(( - DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator() - )), - ) - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) - DynamicPPL.getlogjoint(vi) - end + return map( + DynamicPPL.getlogjoint ∘ last, + reevaluate_with_chain( + model, + chain, + (DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator()), + nothing, + ), + ) end """ @@ -521,14 +544,12 @@ julia> loglikelihood(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( - DynamicPPL.AccumulatorTuple((DynamicPPL.LogLikelihoodAccumulator())) - ) - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) - DynamicPPL.getloglikelihood(vi) - end + return map( + DynamicPPL.getloglikelihood ∘ last, + reevaluate_with_chain( + model, chain, (DynamicPPL.LogLikelihoodAccumulator(),), nothing + ), + ) end """ @@ -561,14 +582,10 @@ julia> logprior(demo_model([1., 2.]), chain) ``` """ function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) - params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain) - return map(params_with_stats) do ps - vi = DynamicPPL.Experimental.OnlyAccsVarInfo( - DynamicPPL.AccumulatorTuple((DynamicPPL.LogPriorAccumulator())) - ) - _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromParams(ps.params, nothing)) - DynamicPPL.getlogprior(vi) - end + return map( + DynamicPPL.getlogprior ∘ last, + reevaluate_with_chain(model, chain, (DynamicPPL.LogPriorAccumulator(),), nothing), + ) end end From 1d79c7fbf014d818ea00f5ef5995913fe9766bd4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 8 Nov 2025 22:50:10 +0000 Subject: [PATCH 9/9] Fix bug --- ext/DynamicPPLMCMCChainsExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 89fb511c1..9b34a9849 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -227,7 +227,7 @@ function DynamicPPL.predict( include_all=false, ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - accs = DynamicPPL.AccumulatorTuple( + accs = ( DynamicPPL.LogPriorAccumulator(), DynamicPPL.LogLikelihoodAccumulator(), DynamicPPL.ValuesAsInModelAccumulator(false),