Skip to content

Commit fa0022b

Browse files
committed
Improve type stability when all parameters are linked or unlinked
1 parent 6849bc2 commit fa0022b

File tree

3 files changed

+101
-22
lines changed

3 files changed

+101
-22
lines changed

src/contexts/init.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct RangeAndLinked
214214
end
215215

216216
"""
217-
VectorWithRanges(
217+
VectorWithRanges{Tlink}(
218218
iden_varname_ranges::NamedTuple,
219219
varname_ranges::Dict{VarName,RangeAndLinked},
220220
vect::AbstractVector{<:Real},
@@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
231231
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
232232
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
233233
"""
234-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
234+
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
235235
# This NamedTuple stores the ranges for identity VarNames
236236
iden_varname_ranges::N
237237
# This Dict stores the ranges for all other VarNames
238238
varname_ranges::Dict{VarName,RangeAndLinked}
239239
# The full parameter vector which we index into to get variable values
240240
vect::T
241+
242+
function VectorWithRanges{Tlink}(
243+
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
244+
) where {Tlink,N,T}
245+
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
246+
end
241247
end
242248

243249
function _get_range_and_linked(
@@ -252,7 +258,29 @@ function init(
252258
::Random.AbstractRNG,
253259
vn::VarName,
254260
dist::Distribution,
255-
p::InitFromParams{<:VectorWithRanges},
261+
p::InitFromParams{<:VectorWithRanges{true}},
262+
)
263+
vr = p.params
264+
range_and_linked = _get_range_and_linked(vr, vn)
265+
transform = from_linked_vec_transform(dist)
266+
return (@view vr.vect[range_and_linked.range]), transform
267+
end
268+
function init(
269+
::Random.AbstractRNG,
270+
vn::VarName,
271+
dist::Distribution,
272+
p::InitFromParams{<:VectorWithRanges{false}},
273+
)
274+
vr = p.params
275+
range_and_linked = _get_range_and_linked(vr, vn)
276+
transform = from_vec_transform(dist)
277+
return (@view vr.vect[range_and_linked.range]), transform
278+
end
279+
function init(
280+
::Random.AbstractRNG,
281+
vn::VarName,
282+
dist::Distribution,
283+
p::InitFromParams{<:VectorWithRanges{nothing}},
256284
)
257285
vr = p.params
258286
range_and_linked = _get_range_and_linked(vr, vn)

src/fasteval.jl

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,21 @@ in the function name.
6464
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
6565
# it _should_ do, but this is wrong regardless.
6666
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
67-
vi = if Threads.nthreads() > 1
68-
param_eltype = DynamicPPL.get_param_eltype(strategy)
67+
return if Threads.nthreads() > 1
68+
# WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable, it
69+
# will cause type instabilities! See also unflatten in src/varinfo.jl.
6970
accs = map(accs) do acc
70-
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
71+
DynamicPPL.convert_eltype(
72+
float_type_with_fallback(DynamicPPL.get_param_eltype(strategy)), acc
73+
)
7174
end
72-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
75+
tsvi = ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
76+
retval, tsvi_new = DynamicPPL._evaluate!!(model, tsvi)
77+
retval, setaccs!!(tsvi_new.varinfo, getaccs(tsvi_new))
7378
else
74-
OnlyAccsVarInfo(accs)
79+
vi = OnlyAccsVarInfo(accs)
80+
DynamicPPL._evaluate!!(model, vi)
7581
end
76-
return DynamicPPL._evaluate!!(model, vi)
7782
end
7883
@inline function fast_evaluate!!(
7984
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
@@ -193,6 +198,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
193198
`unflatten` + `evaluate!!` approach also fails with such models.
194199
"""
195200
struct LogDensityFunction{
201+
# true if all variables are linked; false if all variables are unlinked; nothing if
202+
# mixed
203+
Tlink,
196204
M<:Model,
197205
AD<:Union{ADTypes.AbstractADType,Nothing},
198206
F<:Function,
@@ -216,6 +224,21 @@ struct LogDensityFunction{
216224
# Figure out which variable corresponds to which index, and
217225
# which variables are linked.
218226
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
227+
# Figure out if all variables are linked, unlinked, or mixed
228+
link_statuses = Bool[]
229+
for ral in all_iden_ranges
230+
push!(link_statuses, ral.is_linked)
231+
end
232+
for (_, ral) in all_ranges
233+
push!(link_statuses, ral.is_linked)
234+
end
235+
Tlink = if all(link_statuses)
236+
true
237+
elseif all(!s for s in link_statuses)
238+
false
239+
else
240+
nothing
241+
end
219242
x = [val for val in varinfo[:]]
220243
dim = length(x)
221244
# Do AD prep if needed
@@ -225,12 +248,13 @@ struct LogDensityFunction{
225248
# Make backend-specific tweaks to the adtype
226249
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
227250
DI.prepare_gradient(
228-
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
251+
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
229252
adtype,
230253
x,
231254
)
232255
end
233256
return new{
257+
Tlink,
234258
typeof(model),
235259
typeof(adtype),
236260
typeof(getlogdensity),
@@ -262,36 +286,45 @@ end
262286
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
263287
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
264288

265-
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
289+
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
266290
model::M
267291
getlogdensity::F
268292
iden_varname_ranges::N
269293
varname_ranges::Dict{VarName,RangeAndLinked}
294+
295+
function LogDensityAt{Tlink}(
296+
model::M,
297+
getlogdensity::F,
298+
iden_varname_ranges::N,
299+
varname_ranges::Dict{VarName,RangeAndLinked},
300+
) where {Tlink,M,F,N}
301+
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
302+
end
270303
end
271-
function (f::LogDensityAt)(params::AbstractVector{<:Real})
304+
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
272305
strategy = InitFromParams(
273-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
306+
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
274307
)
275308
accs = fast_ldf_accs(f.getlogdensity)
276309
_, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
277310
return f.getlogdensity(vi)
278311
end
279312

280313
function LogDensityProblems.logdensity(
281-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
282-
)
283-
return LogDensityAt(
314+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
315+
) where {Tlink}
316+
return LogDensityAt{Tlink}(
284317
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
285318
)(
286319
params
287320
)
288321
end
289322

290323
function LogDensityProblems.logdensity_and_gradient(
291-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
292-
)
324+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
325+
) where {Tlink}
293326
return DI.value_and_gradient(
294-
LogDensityAt(
327+
LogDensityAt{Tlink}(
295328
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
296329
),
297330
ldf._adprep,
@@ -300,12 +333,14 @@ function LogDensityProblems.logdensity_and_gradient(
300333
)
301334
end
302335

303-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
336+
function LogDensityProblems.capabilities(
337+
::Type{<:LogDensityFunction{T,M,Nothing}}
338+
) where {T,M}
304339
return LogDensityProblems.LogDensityOrder{0}()
305340
end
306341
function LogDensityProblems.capabilities(
307-
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
308-
) where {M}
342+
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
343+
) where {T,M}
309344
return LogDensityProblems.LogDensityOrder{1}()
310345
end
311346
function LogDensityProblems.dimension(ldf::LogDensityFunction)

test/fasteval.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ using Mooncake: Mooncake
6969
end
7070
end
7171

72+
@testset "LogDensityFunction: Type stability" begin
73+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
74+
unlinked_vi = DynamicPPL.VarInfo(m)
75+
@testset "$islinked" for islinked in (false, true)
76+
vi = if islinked
77+
DynamicPPL.link!!(unlinked_vi, m)
78+
else
79+
unlinked_vi
80+
end
81+
ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
82+
x = vi[:]
83+
@inferred LogDensityProblems.logdensity(ldf, x)
84+
end
85+
end
86+
end
87+
7288
@testset "Fast evaluation: performance" begin
7389
if Threads.nthreads() == 1
7490
# Evaluating these three models with OnlyAccsVarInfo should not lead to any

0 commit comments

Comments
 (0)