Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ struct RangeAndLinked
end

"""
VectorWithRanges(
VectorWithRanges{Tlink}(
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
vect::AbstractVector{<:Real},
Expand All @@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
"""
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
# This NamedTuple stores the ranges for identity VarNames
iden_varname_ranges::N
# This Dict stores the ranges for all other VarNames
varname_ranges::Dict{VarName,RangeAndLinked}
# The full parameter vector which we index into to get variable values
vect::T

function VectorWithRanges{Tlink}(
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
) where {Tlink,N,T}
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
end
end

function _get_range_and_linked(
Expand All @@ -252,7 +258,29 @@ function init(
::Random.AbstractRNG,
vn::VarName,
dist::Distribution,
p::InitFromParams{<:VectorWithRanges},
p::InitFromParams{<:VectorWithRanges{true}},
)
vr = p.params
range_and_linked = _get_range_and_linked(vr, vn)
transform = from_linked_vec_transform(dist)
return (@view vr.vect[range_and_linked.range]), transform
end
"""
    init(rng, vn, dist, p::InitFromParams{<:VectorWithRanges{false}})

Look up the value for `vn` in the flat parameter vector held by `p`, returning
a view of the relevant slice together with the transform needed to reconstruct
the variable. The `false` type parameter indicates that every variable in the
`VectorWithRanges` is stored unlinked, so the plain `from_vec_transform` can be
selected statically without consulting the per-variable link flag.
"""
function init(
    ::Random.AbstractRNG,
    vn::VarName,
    dist::Distribution,
    p::InitFromParams{<:VectorWithRanges{false}},
)
    ranges = p.params
    # Find which slice of the flat vector belongs to this variable.
    ral = _get_range_and_linked(ranges, vn)
    # All variables are unlinked here, so only the vectorisation needs undoing.
    return (@view ranges.vect[ral.range]), from_vec_transform(dist)
end
function init(
::Random.AbstractRNG,
vn::VarName,
dist::Distribution,
p::InitFromParams{<:VectorWithRanges{nothing}},
)
vr = p.params
range_and_linked = _get_range_and_linked(vr, vn)
Expand Down
73 changes: 54 additions & 19 deletions src/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,21 @@ in the function name.
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
param_eltype = DynamicPPL.get_param_eltype(strategy)
return if Threads.nthreads() > 1
# WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable, it
# will cause type instabilities! See also unflatten in src/varinfo.jl.
accs = map(accs) do acc
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
DynamicPPL.convert_eltype(
float_type_with_fallback(DynamicPPL.get_param_eltype(strategy)), acc
)
end
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
tsvi = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
retval, setaccs!!(tsvi_new.varinfo, getaccs(tsvi_new))
else
OnlyAccsVarInfo(accs)
vi = OnlyAccsVarInfo(accs)
DynamicPPL._evaluate!!(model, vi)
end
return DynamicPPL._evaluate!!(model, vi)
end
@inline function fast_evaluate!!(
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
Expand Down Expand Up @@ -193,6 +198,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
`unflatten` + `evaluate!!` approach also fails with such models.
"""
struct LogDensityFunction{
# true if all variables are linked; false if all variables are unlinked; nothing if
# mixed
Tlink,
M<:Model,
AD<:Union{ADTypes.AbstractADType,Nothing},
F<:Function,
Expand All @@ -216,6 +224,21 @@ struct LogDensityFunction{
# Figure out which variable corresponds to which index, and
# which variables are linked.
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
# Figure out if all variables are linked, unlinked, or mixed
link_statuses = Bool[]
for ral in all_iden_ranges
push!(link_statuses, ral.is_linked)
end
for (_, ral) in all_ranges
push!(link_statuses, ral.is_linked)
end
Tlink = if all(link_statuses)
true
elseif all(!s for s in link_statuses)
false
else
nothing
end
x = [val for val in varinfo[:]]
dim = length(x)
# Do AD prep if needed
Expand All @@ -225,12 +248,13 @@ struct LogDensityFunction{
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
DI.prepare_gradient(
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
adtype,
x,
)
end
return new{
Tlink,
typeof(model),
typeof(adtype),
typeof(getlogdensity),
Expand Down Expand Up @@ -262,36 +286,45 @@ end
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
model::M
getlogdensity::F
iden_varname_ranges::N
varname_ranges::Dict{VarName,RangeAndLinked}

function LogDensityAt{Tlink}(
model::M,
getlogdensity::F,
iden_varname_ranges::N,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink,M,F,N}
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
end
end
function (f::LogDensityAt)(params::AbstractVector{<:Real})
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
strategy = InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
)
accs = fast_ldf_accs(f.getlogdensity)
_, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
return f.getlogdensity(vi)
end

function LogDensityProblems.logdensity(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
return LogDensityAt(
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
)(
params
)
end

function LogDensityProblems.logdensity_and_gradient(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return DI.value_and_gradient(
LogDensityAt(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
ldf._adprep,
Expand All @@ -300,12 +333,14 @@ function LogDensityProblems.logdensity_and_gradient(
)
end

function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{T,M,Nothing}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
) where {M}
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.dimension(ldf::LogDensityFunction)
Expand Down
16 changes: 16 additions & 0 deletions test/fasteval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ using Mooncake: Mooncake
end
end

@testset "LogDensityFunction: Type stability" begin
    # Verify that `logdensity` infers a concrete return type for every demo
    # model, with the VarInfo both unlinked and linked. Note the iteration
    # order (false, true): the unlinked case runs before `link!!` mutates
    # the underlying VarInfo.
    @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
        base_vi = DynamicPPL.VarInfo(model)
        @testset "$islinked" for islinked in (false, true)
            vi = islinked ? DynamicPPL.link!!(base_vi, model) : base_vi
            ldf = DynamicPPL.LogDensityFunction(
                model, DynamicPPL.getlogjoint_internal, vi
            )
            params = vi[:]
            # `@inferred` throws (and thus fails the testset) if the call is
            # not concretely inferred.
            @inferred LogDensityProblems.logdensity(ldf, params)
        end
    end
end

@testset "Fast evaluation: performance" begin
if Threads.nthreads() == 1
# Evaluating these three models with OnlyAccsVarInfo should not lead to any
Expand Down