diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..fd3d5ef3e 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,7 +258,29 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, + p::InitFromParams{<:VectorWithRanges{true}}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = from_linked_vec_transform(dist) + return (@view vr.vect[range_and_linked.range]), transform +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges{false}}, +) + vr = p.params + range_and_linked = _get_range_and_linked(vr, vn) + transform = from_vec_transform(dist) + return (@view vr.vect[range_and_linked.range]), transform +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParams{<:VectorWithRanges{nothing}}, ) vr = p.params range_and_linked = _get_range_and_linked(vr, vn) diff --git a/src/fasteval.jl b/src/fasteval.jl index 639e5b6c1..8c14613bc 100644 --- a/src/fasteval.jl +++ b/src/fasteval.jl @@ -64,16 +64,21 @@ in the function name. # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what # it _should_ do, but this is wrong regardless. # https://github.com/TuringLang/DynamicPPL.jl/issues/1086 - vi = if Threads.nthreads() > 1 - param_eltype = DynamicPPL.get_param_eltype(strategy) + return if Threads.nthreads() > 1 + # WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable, it + # will cause type instabilities! See also unflatten in src/varinfo.jl. accs = map(accs) do acc - DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc) + DynamicPPL.convert_eltype( + float_type_with_fallback(DynamicPPL.get_param_eltype(strategy)), acc + ) end - ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + tsvi = ThreadSafeVarInfo(OnlyAccsVarInfo(accs)) + retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi) + retval, setaccs!!(tsvi_new.varinfo, getaccs(tsvi_new)) else - OnlyAccsVarInfo(accs) + vi = OnlyAccsVarInfo(accs) + DynamicPPL._evaluate!!(model, vi) end - return DynamicPPL._evaluate!!(model, vi) end @inline function fast_evaluate!!( model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple @@ -193,6 +198,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -216,6 +224,21 @@ struct LogDensityFunction{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Figure out if all variables are linked, unlinked, or mixed + link_statuses = Bool[] + for ral in all_iden_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in all_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -225,12 +248,13 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -262,15 +286,24 @@ end fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = fast_ldf_accs(f.getlogdensity) _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs) @@ -278,9 +311,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params @@ -288,10 +321,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -300,12 +333,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) diff --git a/test/fasteval.jl b/test/fasteval.jl index a75441c93..213e8da83 100644 --- a/test/fasteval.jl +++ b/test/fasteval.jl @@ -69,6 +69,22 @@ using Mooncake: Mooncake end end +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + @testset "Fast evaluation: performance" begin if Threads.nthreads() == 1 # Evaluating these three models with OnlyAccsVarInfo should not lead to any