From 7cddac775a3bfa0eb0808bb241551e875c6ac0b2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 5 Nov 2025 23:58:25 +0000 Subject: [PATCH 01/23] Fast Log Density Function --- src/DynamicPPL.jl | 1 + src/fastldf.jl | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 src/fastldf.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e66f3fe11..77d527ced 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -191,6 +191,7 @@ include("simple_varinfo.jl") include("compiler.jl") include("pointwise_logdensities.jl") include("logdensityfunction.jl") +include("fastldf.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") diff --git a/src/fastldf.jl b/src/fastldf.jl new file mode 100644 index 000000000..05b024dbe --- /dev/null +++ b/src/fastldf.jl @@ -0,0 +1,91 @@ +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) + +struct RangeAndLinked + # indices that the variable corresponds to in the vectorised parameter + range::UnitRange{Int} + # whether it's linked + is_linked::Bool +end + +struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext + varname_ranges::Dict{VarName,RangeAndLinked} + params::T +end +DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() + +function tilde_assume!!( + ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo +) + # Don't need to read the data from the varinfo at all since it's + # all inside the context. + range_and_linked = ctx.varname_ranges[vn] + y = @view ctx.params[range_and_linked.range] + is_linked = range_and_linked.is_linked + f = if is_linked + from_linked_vec_transform(right) + else + from_vec_transform(right) + end + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +function tilde_observe!!( + ::FastLDFContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::OnlyAccsVarInfo, +) + # This is the same as for DefaultContext + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end + +struct FastLDF{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} + + function FastLDF( + model::Model, + getlogdensity::Function, + # This only works with typed Metadata-varinfo. + # Obviously, this can be generalised later. + varinfo::VarInfo{<:NamedTuple{syms}}, + ) where {syms} + # Figure out which variable corresponds to which index, and + # which variables are linked. + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + all_ranges[vn] = RangeAndLinked(range, is_linked) + offset += len + end + end + return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + end +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + ctx = FastLDFContext(fldf._varname_ranges, params) + model = DynamicPPL.setleafcontext(fldf._model, ctx) + # This can obviously also be optimised for the case where not + # all accumulators are needed. + accs = AccumulatorTuple(( + LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() + )) + _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + return fldf._getlogdensity(vi) +end From 5ed4295f23b0c5cd2d1568de247da249de42913b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 00:33:49 +0000 Subject: [PATCH 02/23] Make it work with AD --- src/fastldf.jl | 56 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 05b024dbe..2178c84a2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -48,17 +48,25 @@ function tilde_observe!!( return left, vi end -struct FastLDF{M<:Model,F<:Function} +struct FastLDF{ + M<:Model, + F<:Function, + AD<:Union{ADTypes.AbstractADType,Nothing}, + ADP<:Union{Nothing,DI.GradientPrep}, +} _model::M _getlogdensity::F _varname_ranges::Dict{VarName,RangeAndLinked} + _adtype::AD + _adprep::ADP function FastLDF( model::Model, getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}, + varinfo::VarInfo{<:NamedTuple{syms}}; + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. @@ -74,18 +82,52 @@ struct FastLDF{M<:Model,F<:Function} offset += len end end - return new{typeof(model),typeof(getlogdensity)}(model, getlogdensity, all_ranges) + # Do AD prep if needed + prep = if adtype === nothing + nothing + else + # Make backend-specific tweaks to the adtype + adtype = tweak_adtype(adtype, model, varinfo) + x = [val for val in varinfo[:]] + DI.prepare_gradient( + FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + ) + end + + return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( + model, getlogdensity, all_ranges, adtype, prep + ) end end -function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - ctx = FastLDFContext(fldf._varname_ranges, params) - model = DynamicPPL.setleafcontext(fldf._model, ctx) +struct FastLogDensityAt{M<:Model,F<:Function} + _model::M + _getlogdensity::F + _varname_ranges::Dict{VarName,RangeAndLinked} +end +function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) + ctx = FastLDFContext(f._varname_ranges, params) + model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - return fldf._getlogdensity(vi) + return f._getlogdensity(vi) +end + +function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) + return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) +end + +function LogDensityProblems.logdensity_and_gradient( + fldf::FastLDF, params::AbstractVector{<:Real} +) + return DI.value_and_gradient( + FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + fldf._adprep, + fldf._adtype, + params, + ) end From e199520dac0b8c0ae82a3ed7fd1e4673bbd7c28a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 01:37:05 +0000 Subject: [PATCH 03/23] Optimise performance for identity VarNames --- src/fastldf.jl | 69 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 2178c84a2..e59b12791 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -12,21 +12,35 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext +struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # The ranges of identity VarNames are stored in a NamedTuple for performance + # reasons. For just plain evaluation this doesn't make _that_ much of a + # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE + # difference (around 4x). Of course, the exact numbers depend on the model. + iden_varname_ranges::N + # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} + # The full parameter vector which we index into to get variable values params::T end DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() +function get_range_and_linked( + ctx::FastLDFContext, ::VarName{sym,typeof(identity)} +) where {sym} + return ctx.iden_varname_ranges[sym] +end +function get_range_and_linked(ctx::FastLDFContext, vn::VarName) + return ctx.varname_ranges[vn] +end function tilde_assume!!( ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo ) # Don't need to read the data from the varinfo at all since it's # all inside the context. - range_and_linked = ctx.varname_ranges[vn] + range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] - is_linked = range_and_linked.is_linked - f = if is_linked + f = if range_and_linked.is_linked from_linked_vec_transform(right) else from_vec_transform(right) @@ -51,11 +65,14 @@ end struct FastLDF{ M<:Model, F<:Function, + N<:NamedTuple, AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } _model::M _getlogdensity::F + # See FastLDFContext for explanation of these two fields + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} _adtype::AD _adprep::ADP @@ -70,6 +87,7 @@ struct FastLDF{ ) where {syms} # Figure out which variable corresponds to which index, and # which variables are linked. + all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms @@ -78,7 +96,16 @@ struct FastLDF{ len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] range = offset:(offset + len - 1) - all_ranges[vn] = RangeAndLinked(range, is_linked) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end offset += len end end @@ -90,23 +117,32 @@ struct FastLDF{ adtype = tweak_adtype(adtype, model, varinfo) x = [val for val in varinfo[:]] DI.prepare_gradient( - FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x + FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + adtype, + x, ) end - return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}( - model, getlogdensity, all_ranges, adtype, prep + return new{ + typeof(model), + typeof(getlogdensity), + typeof(all_iden_ranges), + typeof(adtype), + typeof(prep), + }( + model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep ) end end -struct FastLogDensityAt{M<:Model,F<:Function} +struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F + _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._varname_ranges, params) + ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) # This can obviously also be optimised for the case where not # all accumulators are needed. @@ -118,14 +154,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) - return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params) + return FastLogDensityAt( + fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + )( + params + ) end function LogDensityProblems.logdensity_and_gradient( fldf::FastLDF, params::AbstractVector{<:Real} ) return DI.value_and_gradient( - FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges), + FastLogDensityAt( + fldf._model, + fldf._getlogdensity, + fldf._iden_varname_ranges, + fldf._varname_ranges, + ), fldf._adprep, fldf._adtype, params, From 4cefaca4803107da0570430277a3f05a27ae2146 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 02:43:55 +0000 Subject: [PATCH 04/23] Mark `get_range_and_linked` as having zero derivative --- ext/DynamicPPLEnzymeCoreExt.jl | 13 ++++++------- ext/DynamicPPLMooncakeExt.jl | 3 ++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 35159636f..29a4e2cc7 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -1,16 +1,15 @@ module DynamicPPLEnzymeCoreExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using EnzymeCore -else - using ..DynamicPPL: DynamicPPL - using ..EnzymeCore -end +using DynamicPPL: DynamicPPL +using EnzymeCore # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing +# Likewise for get_range_and_linked. +@inline EnzymeCore.EnzymeRules.inactive( + ::typeof(DynamicPPL.get_range_and_linked), args... +) = nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 23a3430eb..e49b81cb2 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,10 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, is_transformed +using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked using Mooncake: Mooncake # This is purely an optimisation. Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} end # module From 6dfd106ace57d776799c487c5dccbaed48211b13 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 03:11:32 +0000 Subject: [PATCH 05/23] Update comment --- src/fastldf.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index e59b12791..6c8798d4c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -13,10 +13,8 @@ struct RangeAndLinked end struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for performance - # reasons. For just plain evaluation this doesn't make _that_ much of a - # difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE - # difference (around 4x). Of course, the exact numbers depend on the model. + # The ranges of identity VarNames are stored in a NamedTuple for improved performance + # (it's around 1.5x faster). iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} From 4ca9cf74fe1fd209a83fa1e4e99188b7e74e9a98 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Nov 2025 13:05:35 +0000 Subject: [PATCH 06/23] Squeeze down VarInfo allocations --- src/varinfo.jl | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..f149c3522 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -394,7 +394,9 @@ end for f in names mdf = :(metadata.$f) len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)]))) + push!( + exprs, :($f = unflatten_metadata($mdf, @view x[($offset + 1):($offset + $len)])) + ) offset = :($offset + $len) end length(exprs) == 0 && return :(NamedTuple()) @@ -755,7 +757,10 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, # since then we might be returning a `SubArray` rather than an `Array`, which is typically # what a bijector would result in, even if the input is a view (`SubArray`). # TODO(torfjelde): An alternative is to implement `view` directly instead. -getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) +function getindex_internal(md::Metadata, vn::VarName) + rng = getrange(md, vn) + return @view md.vals[rng] +end function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) end @@ -1495,8 +1500,21 @@ space. If some but only some of the variables in `vi` are transformed, this function will return `true`. This behavior will likely change in the future. """ -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) +function is_transformed(vi::NTVarInfo) + return is_transformed(vi.metadata) +end + +@generated function is_transformed(nt::NamedTuple{names}) where {names} + expr = Expr(:block) + push!(expr.args, :(result = false)) + for n in names + push!(expr.args, :(result = result || is_transformed(nt.$n))) + end + return expr +end + +function is_transformed(md::Metadata) + return any(md.is_transformed) end # The default getindex & setindex!() for get & set values @@ -1552,7 +1570,7 @@ end @generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} expr = Expr(:tuple) for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) + push!(expr.args, :(@view metadata.$f.vals[ranges.$f])) end return expr end From 7c6e8c1b692663357cd339bcb96a60cbfd65adfc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 6 Nov 2025 13:08:32 +0000 Subject: [PATCH 07/23] Remove old out-of-date comment --- src/varinfo.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index f149c3522..0d77c70e6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -753,10 +753,6 @@ function getdist(::VarNamedVector, ::VarName) end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, -# since then we might be returning a `SubArray` rather than an `Array`, which is typically -# what a bijector would result in, even if the input is a view (`SubArray`). -# TODO(torfjelde): An alternative is to implement `view` directly instead. function getindex_internal(md::Metadata, vn::VarName) rng = getrange(md, vn) return @view md.vals[rng] From 5c817a46ad0b040a1d39ee71d77861ba5b13d86f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:25:16 +0000 Subject: [PATCH 08/23] implement `is_transformed(::VarNamedVector)` --- src/varnamedvector.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 17b851d1d..a81f33ea5 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -367,6 +367,12 @@ Return a boolean for whether `vn` is guaranteed to have been transformed so that is all of Euclidean space. """ is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +""" + is_transformed(vnv::VarNamedVector) + +Return true if any variable in `vnv` is guaranteed to have been transformed. +""" +is_transformed(vnv::VarNamedVector) = any(vnv.is_unconstrained) """ set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) From 93daa2b3286d1cc68e1c176f9f34b001e3561a95 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:36:15 +0000 Subject: [PATCH 09/23] Handle errors in benchmark suite --- benchmarks/benchmarks.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..cf5e7daa6 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -87,12 +87,18 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - suite = make_suite(model, varinfo_choice, adbackend, islinked) - results = run(suite) - eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time - ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time + try + suite = make_suite(model, varinfo_choice, adbackend, islinked) + results = run(suite) + eval_time = median(results["evaluation"]).time + relative_eval_time = eval_time / reference_time + ad_eval_time = median(results["gradient"]).time + relative_ad_eval_time = ad_eval_time / eval_time + catch e + @warn "Benchmark failed for $model_name with error: $e" + relative_eval_time = NaN + relative_ad_eval_time = NaN + end push!( results_table, ( From 41ee7f3d7b58503de91d835c42d394cd6e1818f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 13:53:36 +0000 Subject: [PATCH 10/23] make AD testing / benchmarking use FastLDF --- benchmarks/src/DynamicPPLBenchmarks.jl | 4 +- src/fastldf.jl | 2 +- src/test_utils/ad.jl | 9 +-- test/ad.jl | 77 +++++++------------------- 4 files changed, 24 insertions(+), 68 deletions(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 225e40cd8..e6988d3f2 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -94,9 +94,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend - ) + f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend) # The parameters at which we evaluate f. θ = vi[:] diff --git a/src/fastldf.jl b/src/fastldf.jl index 6c8798d4c..7ec193891 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,7 @@ struct FastLDF{ getlogdensity::Function, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}; + varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) where {syms} # Figure out which variable corresponds to which index, and diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a49ffd18b..fbbae85b7 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,8 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: - Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link +using DynamicPPL: Model, FastLDF, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -265,7 +264,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) + ldf = FastLDF(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 @@ -282,9 +281,7 @@ function run_ad( value_true = test.value grad_true = test.grad elseif test isa WithBackend - ldf_reference = LogDensityFunction( - model, getlogdensity, varinfo; adtype=test.adtype - ) + ldf_reference = FastLDF(model, getlogdensity, varinfo; adtype=test.adtype) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad_true = collect(grad_true) diff --git a/test/ad.jl b/test/ad.jl index d7505aab2..6d140197e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -using DynamicPPL: LogDensityFunction +using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Automatic differentiation" begin @@ -15,64 +15,25 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] end - @testset "Unsupported backends" begin - @model demo() = x ~ Normal() - @test_logs (:warn, r"not officially supported") LogDensityFunction( - demo(); adtype=AutoZygote() - ) - end - @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) - vns = DynamicPPL.TestUtils.varnames(m) - varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) - - # Calculate reference logp + gradient of logp using ForwardDiff - ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) - ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual - - @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" - - # Put predicates here to avoid long lines - is_mooncake = adtype isa AutoMooncake - is_1_10 = v"1.10" <= VERSION < v"1.11" - is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = - linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict} - - # Mooncake doesn't work with several combinations of SimpleVarInfo. - if is_mooncake && is_1_11 && is_svi_vnv - # https://github.com/compintell/Mooncake.jl/issues/470 - @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_vnv - # TODO: report upstream - @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - elseif is_mooncake && is_1_10 && is_svi_od - # TODO: report upstream - @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint_internal, linked_varinfo; adtype=adtype - ) - else - @test run_ad( - m, - adtype; - varinfo=linked_varinfo, - test=WithExpectedResult(ref_logp, ref_grad), - ) isa Any - end - end + varinfo = VarInfo(m) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = FastLDF(m, getlogjoint_internal, linked_varinfo) + x = DynamicPPL.getparams(f) + + # Calculate reference logp + gradient of logp using ForwardDiff + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual + + @testset "$adtype" for adtype in test_adtypes + @info "Testing AD on: $(m.f) - $adtype" + @test run_ad( + m, + adtype; + varinfo=linked_varinfo, + test=WithExpectedResult(ref_logp, ref_grad), + ) isa Any end end end @@ -83,7 +44,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest test_m = randn(2, 3) function eval_logp_and_grad(model, m, adtype) - ldf = LogDensityFunction(model(); adtype=adtype) + ldf = FastLDF(model(); adtype=adtype) return LogDensityProblems.logdensity_and_gradient(ldf, m[:]) end From 22e32a6fbd0ed365bd327a3dd5fbd4d6d3016bfb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:00:40 +0000 Subject: [PATCH 11/23] Fix tests --- src/fastldf.jl | 2 +- test/ad.jl | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 7ec193891..61194ab25 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,7 +77,7 @@ struct FastLDF{ function FastLDF( model::Model, - getlogdensity::Function, + getlogdensity::Function=getlogjoint_internal, # This only works with typed Metadata-varinfo. # Obviously, this can be generalised later. varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); diff --git a/test/ad.jl b/test/ad.jl index 6d140197e..48b1b64ec 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,5 +1,6 @@ using DynamicPPL: FastLDF using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using Random: Xoshiro @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -17,10 +18,10 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(m) + varinfo = VarInfo(Xoshiro(468), m) linked_varinfo = DynamicPPL.link(varinfo, m) f = FastLDF(m, getlogjoint_internal, linked_varinfo) - x = DynamicPPL.getparams(f) + x = linked_varinfo[:] # Calculate reference logp + gradient of logp using ForwardDiff ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) From 79cc1286b87d35fe082f2f02b803af4ebb2cc653 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:18:08 +0000 Subject: [PATCH 12/23] Optimise away `make_evaluate_args_and_kwargs` --- src/fastldf.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 61194ab25..ebaf002b4 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -147,10 +147,21 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) + args = map(maybe_deepcopy, model.args) + _, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...) return f._getlogdensity(vi) end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From f7c6a78ba3446246abda6798a77513f6893aa49a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:23:58 +0000 Subject: [PATCH 13/23] const func annotation --- test/integration/enzyme/main.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index b40bbeb8f..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES From b1a76509b6e88ac7f6099039bd042d3d502d4848 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:28:34 +0000 Subject: [PATCH 14/23] Disable benchmarks on non-typed-Metadata-VarInfo --- benchmarks/benchmarks.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 035d8ff49..5fe0320cc 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -59,11 +59,11 @@ chosen_combinations = [ false, ), ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), + # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), From e60873a7e2f000733ef1390bcca7e164776e56db Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:36:34 +0000 Subject: [PATCH 15/23] Fix `_evaluate!!` correctly to handle submodels --- src/fastldf.jl | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index ebaf002b4..309155606 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -133,6 +133,23 @@ struct FastLDF{ end end +function _evaluate!!( + model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo +) where {F,A,D,M,TA,TD} + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. + deepcopy(x) + else + x + end +end + struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M _getlogdensity::F @@ -147,21 +164,10 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) accs = AccumulatorTuple(( LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() )) - # _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) - args = map(maybe_deepcopy, model.args) - _, vi = model.f(model, OnlyAccsVarInfo(accs), args...; model.defaults...) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) return f._getlogdensity(vi) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From fa0664ee9ec3c6d0eae891af1508423ac77f5643 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 14:41:51 +0000 Subject: [PATCH 16/23] Actually fix submodel evaluate --- src/fastldf.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 309155606..b1794ffa2 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -2,7 +2,6 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs -DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) struct RangeAndLinked @@ -133,11 +132,13 @@ struct FastLDF{ end end -function _evaluate!!( - model::Model{F,A,D,M,TA,TD,<:FastLDFContext}, varinfo::OnlyAccsVarInfo -) where {F,A,D,M,TA,TD} - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) +function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastLDFContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end end maybe_deepcopy(@nospecialize(x)) = x function maybe_deepcopy(x::AbstractArray{T}) where {T} From 09a1fbb4787cf70fd42c794ab4365ff99395964d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:04:26 +0000 Subject: [PATCH 17/23] Document thoroughly and organise code --- src/compiler.jl | 38 +++--- src/fastldf.jl | 341 +++++++++++++++++++++++++++++++++++++----------- src/model.jl | 20 ++- 3 files changed, 306 insertions(+), 93 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index badba9f9d..3324780ca 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -718,14 +718,15 @@ end # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why? # TODO(mhauru) This function needs a more comprehensive docstring. """ - matchingvalue(vi, value) + matchingvalue(param_eltype, value) -Convert the `value` to the correct type for the `vi` object. +Convert the `value` to the correct type, given the element type of the parameters +being used to evaluate the model. """ -function matchingvalue(vi, value) +function matchingvalue(param_eltype, value) T = typeof(value) if hasmissing(T) - _value = convert(get_matching_type(vi, T), value) + _value = convert(get_matching_type(param_eltype, T), value) # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we # are happy to return `value` as-is? if _value === value @@ -738,29 +739,30 @@ function matchingvalue(vi, value) end end -function matchingvalue(vi, value::FloatOrArrayType) - return get_matching_type(vi, value) +function matchingvalue(param_eltype, value::FloatOrArrayType) + return get_matching_type(param_eltype, value) end -function matchingvalue(vi, ::TypeWrap{T}) where {T} - return TypeWrap{get_matching_type(vi, T)}() +function matchingvalue(param_eltype, ::TypeWrap{T}) where {T} + return TypeWrap{get_matching_type(param_eltype, T)}() end # TODO(mhauru) This function needs a more comprehensive docstring. What is it for? """ - get_matching_type(vi, ::TypeWrap{T}) where {T} + get_matching_type(param_eltype, ::TypeWrap{T}) where {T} -Get the specialized version of type `T` for `vi`. +Get the specialized version of type `T`, given an element type of the parameters +being used to evaluate the model. """ get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/fastldf.jl b/src/fastldf.jl index b1794ffa2..215202230 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -1,9 +1,108 @@ +""" +fasteval.jl +----------- + +Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a +given set of parameters: + +1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters + inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation. + +2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores + them inside a VarInfo's metadata. + +In general, both of these approaches work fine, but the fact that they modify the VarInfo's +metadata can often be quite wasteful. In particular, it is very common that the only outputs +we care about from model evaluation are those which are stored in accumulators, such as log +probability densities, or `ValuesAsInModel`. + +To avoid this issue, we implement here `OnlyAccsVarInfo`, which is a VarInfo that only +contains accumulators. When evaluating a model with `OnlyAccsVarInfo`, it is mandatory that +the model's leaf context is a `FastEvalContext`, which provides extremely fast access to +parameter values. No writing of values into VarInfo metadata is performed at all. + +Vector parameters +----------------- + +We first consider the case of parameter vectors, i.e., the case which would normally be +handled by `unflatten` and `evaluate!!`. Unfortunately, it is not enough to just store +the vector of parameters in the `FastEvalContext`, because it is not clear: + + - which parts of the vector correspond to which random variables, and + - whether the variables are linked or unlinked. + +Traditionally, this problem has been solved by `unflatten`, because that function would +place values into the VarInfo's metadata alongside the information about ranges and linking. +However, we want to avoid doing this. Thus, here, we _extract this information from the +VarInfo_ a single time when constructing a `FastLDF` object. + +Note that this assumes that the ranges and link status are static throughout the lifetime of +the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable +numbers of parameters, or models which may visit random variables in different orders depending +on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a +general limitation of vectorised parameters: the original `unflatten` + `evaluate!!` +approach also fails with such models. + +NamedTuple and Dict parameters +------------------------------ + +Fast evaluation has not yet been extended to NamedTuple and Dict parameters. Such +representations are capable of handling models with variable sizes and stochastic control +flow. + +However, the path towards implementing these is straightforward: + +1. Currently, `FastLDFVectorContext` allows users to input a VarName and obtain the parameter + value, plus a boolean indicating whether the value is linked or unlinked. See the + `get_range_and_linked` function for details. + +2. We would need to implement similar contexts for NamedTuple and Dict parameters. The + functionality would be quite similar to `InitContext(InitFromParams(...))`. +""" + +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `accumulate_assume!!` and `accumulate_observe!!` functions. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using this outside of FastLDF will lead to errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +That is because values for random variables are obtained by reading from a separate entity +(such as a `FastLDFContext`), rather than from the VarInfo itself. +""" struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) +function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) + # Because the VarInfo has no parameters stored in it, we need to get the eltype from the + # model's leaf context. This is only possible if said leaf context is indeed a FastEval + # context. + leaf_ctx = DynamicPPL.leafcontext(model) + if leaf_ctx isa FastEvalVectorContext + return eltype(leaf_ctx.params) + else + error( + "OnlyAccsVarInfo can only be used with FastEval contexts, found $(typeof(leaf_ctx))", + ) + end +end +""" + RangeAndLinked + +Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable +in the model will in general correspond to a sub-vector of `params`. This struct stores +information about that range, as well as whether the sub-vector represents a linked value or +an unlinked value. + +$(TYPEDFIELDS) +""" struct RangeAndLinked # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} @@ -11,30 +110,55 @@ struct RangeAndLinked is_linked::Bool end -struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext - # The ranges of identity VarNames are stored in a NamedTuple for improved performance - # (it's around 1.5x faster). +""" + AbstractFastEvalContext + +Abstract type representing fast evaluation contexts. This currently is only subtyped by +`FastEvalVectorContext`. However, in the future, similar contexts may be implemented for +NamedTuple and Dict parameters. +""" +abstract type AbstractFastEvalContext <: AbstractContext end +DynamicPPL.NodeTrait(::AbstractFastEvalContext) = IsLeaf() + +""" + FastEvalVectorContext( + iden_varname_ranges::NamedTuple, + varname_ranges::Dict{VarName,RangeAndLinked}, + params::AbstractVector{<:Real}, + ) + +A context that wraps a vector of parameter values, plus information about how random +variables map to ranges in that vector. + +In the simplest case, this could be accomplished only with a single dictionary mapping +VarNames to ranges and link status. However, for performance reasons, we separate out +VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All +non-identity-optic VarNames are stored in the `varname_ranges` Dict. + +It would be nice to unify the NamedTuple and Dict approach. See, e.g. +https://github.com/TuringLang/DynamicPPL.jl/issues/1116. +""" +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext + # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values params::T end -DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() function get_range_and_linked( - ctx::FastLDFContext, ::VarName{sym,typeof(identity)} + ctx::FastEvalVectorContext, ::VarName{sym,typeof(identity)} ) where {sym} return ctx.iden_varname_ranges[sym] end -function get_range_and_linked(ctx::FastLDFContext, vn::VarName) +function get_range_and_linked(ctx::FastEvalVectorContext, vn::VarName) return ctx.varname_ranges[vn] end function tilde_assume!!( - ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo + ctx::FastEvalVectorContext, right::Distribution, vn::VarName, vi::AbstractVarInfo ) - # Don't need to read the data from the varinfo at all since it's - # all inside the context. + # Note that this function does not use the metadata field of `vi` at all. range_and_linked = get_range_and_linked(ctx, vn) y = @view ctx.params[range_and_linked.range] f = if range_and_linked.is_linked @@ -48,64 +172,111 @@ function tilde_assume!!( end function tilde_observe!!( - ::FastLDFContext, + ::FastEvalVectorContext, right::Distribution, left, vn::Union{VarName,Nothing}, - vi::OnlyAccsVarInfo, + vi::AbstractVarInfo, ) # This is the same as for DefaultContext vi = accumulate_observe!!(vi, right, left, vn) return left, vi end +######################################## +# Log-density functions using FastEval # +######################################## + +""" + FastLDF( + model::Model, + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=VarInfo(model); + adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + ) + +A struct which contains a model, along with all the information necessary to: + + - calculate its log density at a given point; + - and if `adtype` is provided, calculate the gradient of the log density at that point. + +This information can be extracted using the LogDensityProblems.jl interface, specifically, +using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If +`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term + for any variables that have been linked in the provided VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of + linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of + linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking, + since transforms are only applied to random variables) + +!!! note + By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of + `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created + with a linked or unlinked VarInfo. This is done primarily to ease interoperability with + MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created for you. If +you provide a different function, you have to manually create a VarInfo and pass it as the +third argument. + +If the `adtype` keyword argument is provided, then this struct will also store the adtype +along with other information for efficient calculation of the gradient of the log density. +Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend +itself to have been loaded (e.g. with `import Backend`). + +## Fields + +Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from: + +- `fastldf.model`: The original model from which this `FastLDF` was constructed. +- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD + type was provided. + +## Extended help + +`FastLDF` uses `FastEvalVectorContext` internally to provide extremely rapid evaluation of +the model given a vector of parameters. + +Because it is common to call `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient` within tight loops, it is beneficial for us to +pre-compute as much of the information as possible when constructing the `FastLDF` object. +In particular, we use the provided VarInfo's metadata to extract the mapping from VarNames +to ranges and link status, and store this mapping inside the `FastLDF` object. We can later +use this to construct a FastEvalVectorContext, without having to look into a metadata again. +""" struct FastLDF{ M<:Model, + AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, N<:NamedTuple, - AD<:Union{ADTypes.AbstractADType,Nothing}, ADP<:Union{Nothing,DI.GradientPrep}, } - _model::M + model::M + adtype::AD _getlogdensity::F - # See FastLDFContext for explanation of these two fields + # See FastLDFContext for explanation of these two fields. _iden_varname_ranges::N _varname_ranges::Dict{VarName,RangeAndLinked} - _adtype::AD _adprep::ADP function FastLDF( model::Model, getlogdensity::Function=getlogjoint_internal, - # This only works with typed Metadata-varinfo. - # Obviously, this can be generalised later. - varinfo::VarInfo{<:NamedTuple{syms}}=VarInfo(model); + varinfo::AbstractVarInfo=VarInfo(model); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - ) where {syms} + ) # Figure out which variable corresponds to which index, and # which variables are linked. - all_iden_ranges = NamedTuple() - all_ranges = Dict{VarName,RangeAndLinked}() - offset = 1 - for sym in syms - md = varinfo.metadata[sym] - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end - end + all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) # Do AD prep if needed prep = if adtype === nothing nothing @@ -119,37 +290,37 @@ struct FastLDF{ x, ) end - return new{ typeof(model), + typeof(adtype), typeof(getlogdensity), typeof(all_iden_ranges), - typeof(adtype), typeof(prep), }( - model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep ) end end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastLDFContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end +################################### +# LogDensityProblems.jl interface # +################################### +""" + fast_ldf_accs(getlogdensity::Function) + +Determine which accumulators are needed for fast evaluation with the given +`getlogdensity` function. +""" +fast_ldf_accs(::Function) = default_accumulators() +fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators() +function fast_ldf_accs(::typeof(getlogjoint)) + return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. - deepcopy(x) - else - x - end +function fast_ldf_accs(::typeof(getlogprior_internal)) + return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator())) end +fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) +fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _model::M @@ -158,20 +329,15 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} _varname_ranges::Dict{VarName,RangeAndLinked} end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) - ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) + ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - # This can obviously also be optimised for the case where not - # all accumulators are needed. - accs = AccumulatorTuple(( - LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() - )) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(accs)) + _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) return f._getlogdensity(vi) end function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( - fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges )( params ) @@ -182,13 +348,42 @@ function LogDensityProblems.logdensity_and_gradient( ) return DI.value_and_gradient( FastLogDensityAt( - fldf._model, - fldf._getlogdensity, - fldf._iden_varname_ranges, - fldf._varname_ranges, + fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges ), fldf._adprep, - fldf._adtype, + fldf.adtype, params, ) end + +###################################################### +# Helper functions to extract ranges and link status # +###################################################### + +# TODO: Fails for other VarInfo types. +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + # TODO: Fails for VarNamedVector. + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple(( + AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), + )), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + end + return all_iden_ranges, all_ranges +end diff --git a/src/model.jl b/src/model.jl index edb042ba9..6ca06aea6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,9 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)...) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1006,22 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, model::Model) + +Get the element type of the parameters being used to evaluate the `model` from the +`varinfo`. For example, when performing AD with ForwardDiff, this should return +`ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +See `OnlyAccsVarInfo` for an example of where this is not true (the parameters are instead +stored in the model's context). +""" +get_param_eltype(varinfo::AbstractVarInfo, ::Model) = eltype(varinfo) + """ getargnames(model::Model) From 7306ba46158cffa4a1dca405f5fe8fd68930046c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:27:57 +0000 Subject: [PATCH 18/23] Support more VarInfos, make it thread-safe (?) --- src/fastldf.jl | 99 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 21 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 215202230..87dd698dc 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -77,13 +77,14 @@ struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo accs::Accs end OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. - leaf_ctx = DynamicPPL.leafcontext(model) + leaf_ctx = DynamicPPL.leafcontext(model.context) if leaf_ctx isa FastEvalVectorContext return eltype(leaf_ctx.params) else @@ -138,7 +139,8 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to unify the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext +struct FastEvalVectorContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: + AbstractFastEvalContext # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames @@ -331,7 +333,17 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - _, vi = _evaluate!!(model, OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity))) + only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, + # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` + # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic + # here. + vi = if Threads.nthreads() > 1 + ThreadSafeVarInfo(only_accs_vi) + else + only_accs_vi + end + _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) end @@ -360,30 +372,75 @@ end # Helper functions to extract ranges and link status # ###################################################### -# TODO: Fails for other VarInfo types. +# TODO: Fails for SimpleVarInfo. Do I really care enough? Ehhh, honestly, debatable. + +""" + get_ranges_and_linked(varinfo::VarInfo) + +Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter +representation, along with whether each variable is linked or unlinked. + +This function should return a tuple containing: + +- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked` +- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`. +""" function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} all_iden_ranges = NamedTuple() all_ranges = Dict{VarName,RangeAndLinked}() offset = 1 for sym in syms md = varinfo.metadata[sym] - # TODO: Fails for VarNamedVector. - for (vn, idx) in md.idcs - len = length(md.ranges[idx]) - is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) - if AbstractPPL.getoptic(vn) === identity - all_iden_ranges = merge( - all_iden_ranges, - NamedTuple(( - AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), - )), - ) - else - all_ranges[vn] = RangeAndLinked(range, is_linked) - end - offset += len - end + this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( + md, offset + ) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + offset = new_offset end return all_iden_ranges, all_ranges end +function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + len = length(md.ranges[idx]) + is_linked = md.is_transformed[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + len = length(vnv.ranges[idx]) + is_linked = vnv.is_unconstrained[idx] + range = offset:(offset + len - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += len + end + return all_iden_ranges, all_ranges, offset +end From 53bccc13dbd7976a2e12a4c3057ba71c2d56e1ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 17:46:06 +0000 Subject: [PATCH 19/23] fix bug in parsing ranges from metadata/VNV --- src/fastldf.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 87dd698dc..15326a83b 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -391,16 +391,13 @@ function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms offset = 1 for sym in syms md = varinfo.metadata[sym] - this_md_iden, this_md_others, new_offset = get_ranges_and_linked_metadata( - md, offset - ) + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) all_iden_ranges = merge(all_iden_ranges, this_md_iden) all_ranges = merge(all_ranges, this_md_others) - offset = new_offset end return all_iden_ranges, all_ranges end -function get_ranges_and_linked(varinfo::VarInfo{<:Metadata}) +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) return all_iden, all_others end @@ -409,9 +406,8 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in md.idcs - len = length(md.ranges[idx]) is_linked = md.is_transformed[idx] - range = offset:(offset + len - 1) + range = md.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -420,7 +416,7 @@ function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end @@ -429,9 +425,8 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) all_ranges = Dict{VarName,RangeAndLinked}() offset = start_offset for (vn, idx) in vnv.varname_to_index - len = length(vnv.ranges[idx]) is_linked = vnv.is_unconstrained[idx] - range = offset:(offset + len - 1) + range = vnv.ranges[idx] .+ (start_offset - 1) if AbstractPPL.getoptic(vn) === identity all_iden_ranges = merge( all_iden_ranges, @@ -440,7 +435,7 @@ function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) else all_ranges[vn] = RangeAndLinked(range, is_linked) end - offset += len + offset += length(range) end return all_iden_ranges, all_ranges, offset end From 30b9247080f7e15fa8e5259b2fbc8237c58b5487 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:04:40 +0000 Subject: [PATCH 20/23] Fix get_param_eltype for TSVI --- src/fastldf.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index 15326a83b..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -80,7 +80,9 @@ OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) -function DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, model::Model) +function DynamicPPL.get_param_eltype( + ::Union{OnlyAccsVarInfo,ThreadSafeVarInfo{<:OnlyAccsVarInfo}}, model::Model +) # Because the VarInfo has no parameters stored in it, we need to get the eltype from the # model's leaf context. This is only possible if said leaf context is indeed a FastEval # context. @@ -333,15 +335,16 @@ end function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) ctx = FastEvalVectorContext(f._iden_varname_ranges, f._varname_ranges, params) model = DynamicPPL.setleafcontext(f._model, ctx) - only_accs_vi = OnlyAccsVarInfo(fast_ldf_accs(f._getlogdensity)) + accs = fast_ldf_accs(f._getlogdensity) # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!, # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!` # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic # here. vi = if Threads.nthreads() > 1 - ThreadSafeVarInfo(only_accs_vi) + accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(params)), acc), accs) + ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) else - only_accs_vi + OnlyAccsVarInfo(accs) end _, vi = _evaluate!!(model, vi) return f._getlogdensity(vi) From 316937a2ff3014b4a2383d2dadc79134db60a258 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:17:26 +0000 Subject: [PATCH 21/23] Disable Enzyme benchmark --- benchmarks/benchmarks.jl | 2 +- src/fastldf.jl | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 5fe0320cc..e78bf602f 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -66,7 +66,7 @@ chosen_combinations = [ # ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), + # ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), diff --git a/src/fastldf.jl b/src/fastldf.jl index eaca0c795..c06f3495c 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,6 +350,25 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end +function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) + if leafcontext(model.context) isa FastEvalVectorContext + args = map(maybe_deepcopy, model.args) + return model.f(model, varinfo, args...; model.defaults...) + else + error("Shouldn't happen") + end +end +maybe_deepcopy(@nospecialize(x)) = x +function maybe_deepcopy(x::AbstractArray{T}) where {T} + if T >: Missing + # avoid overwriting missing elements of model arguments when + # evaluating the model. + deepcopy(x) + else + x + end +end + function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges From 9c71e8119553000105495559213d8ecd96d1dd81 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:20:05 +0000 Subject: [PATCH 22/23] Revert "Handle errors in benchmark suite" This reverts commit 93daa2b3286d1cc68e1c176f9f34b001e3561a95. --- benchmarks/benchmarks.jl | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e5e88f067..e78bf602f 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -87,18 +87,12 @@ results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - try - suite = make_suite(model, varinfo_choice, adbackend, islinked) - results = run(suite) - eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time - ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time - catch e - @warn "Benchmark failed for $model_name with error: $e" - relative_eval_time = NaN - relative_ad_eval_time = NaN - end + suite = make_suite(model, varinfo_choice, adbackend, islinked) + results = run(suite) + eval_time = median(results["evaluation"]).time + relative_eval_time = eval_time / reference_time + ad_eval_time = median(results["gradient"]).time + relative_ad_eval_time = ad_eval_time / eval_time push!( results_table, ( From 075cee8a3da81e1a96558274448fe8a7458f4f6e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 6 Nov 2025 18:20:31 +0000 Subject: [PATCH 23/23] Don't override _evaluate!!, that breaks ForwardDiff (sometimes) --- src/fastldf.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/fastldf.jl b/src/fastldf.jl index c06f3495c..eaca0c795 100644 --- a/src/fastldf.jl +++ b/src/fastldf.jl @@ -350,25 +350,6 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) return f._getlogdensity(vi) end -function _evaluate!!(model::Model, varinfo::OnlyAccsVarInfo) - if leafcontext(model.context) isa FastEvalVectorContext - args = map(maybe_deepcopy, model.args) - return model.f(model, varinfo, args...; model.defaults...) - else - error("Shouldn't happen") - end -end -maybe_deepcopy(@nospecialize(x)) = x -function maybe_deepcopy(x::AbstractArray{T}) where {T} - if T >: Missing - # avoid overwriting missing elements of model arguments when - # evaluating the model. - deepcopy(x) - else - x - end -end - function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) return FastLogDensityAt( fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges