@@ -64,16 +64,21 @@ in the function name.
6464 # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
6565 # it _should_ do, but this is wrong regardless.
6666 # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
67- vi = if Threads. nthreads () > 1
68- param_eltype = DynamicPPL. get_param_eltype (strategy)
67+ return if Threads. nthreads () > 1
68+ # WARNING: Do NOT move get_param_eltype(strategy) into an intermediate variable, it
69+ # will cause type instabilities! See also unflatten in src/varinfo.jl.
6970 accs = map (accs) do acc
70- DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
71+ DynamicPPL. convert_eltype (
72+ float_type_with_fallback (DynamicPPL. get_param_eltype (strategy)), acc
73+ )
7174 end
72- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
75+ tsvi = ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
76+ retval, tsvi_new = DynamicPPL. _evaluate!! (model, tsvi)
77+ retval, setaccs!! (tsvi_new. varinfo, getaccs (tsvi_new))
7378 else
74- OnlyAccsVarInfo (accs)
79+ vi = OnlyAccsVarInfo (accs)
80+ DynamicPPL. _evaluate!! (model, vi)
7581 end
76- return DynamicPPL. _evaluate!! (model, vi)
7782end
7883@inline function fast_evaluate!! (
7984 model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
@@ -193,6 +198,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
193198`unflatten` + `evaluate!!` approach also fails with such models.
194199"""
195200struct LogDensityFunction{
201+ # true if all variables are linked; false if all variables are unlinked; nothing if
202+ # mixed
203+ Tlink,
196204 M<: Model ,
197205 AD<: Union{ADTypes.AbstractADType,Nothing} ,
198206 F<: Function ,
@@ -216,6 +224,21 @@ struct LogDensityFunction{
216224 # Figure out which variable corresponds to which index, and
217225 # which variables are linked.
218226 all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
227+ # Figure out if all variables are linked, unlinked, or mixed
228+ link_statuses = Bool[]
229+ for ral in all_iden_ranges
230+ push! (link_statuses, ral. is_linked)
231+ end
232+ for (_, ral) in all_ranges
233+ push! (link_statuses, ral. is_linked)
234+ end
235+ Tlink = if all (link_statuses)
236+ true
237+ elseif all (! s for s in link_statuses)
238+ false
239+ else
240+ nothing
241+ end
219242 x = [val for val in varinfo[:]]
220243 dim = length (x)
221244 # Do AD prep if needed
@@ -225,12 +248,13 @@ struct LogDensityFunction{
225248 # Make backend-specific tweaks to the adtype
226249 adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo)
227250 DI. prepare_gradient (
228- LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
251+ LogDensityAt {Tlink} (model, getlogdensity, all_iden_ranges, all_ranges),
229252 adtype,
230253 x,
231254 )
232255 end
233256 return new{
257+ Tlink,
234258 typeof (model),
235259 typeof (adtype),
236260 typeof (getlogdensity),
@@ -262,36 +286,45 @@ end
262286fast_ldf_accs (:: typeof (getlogprior)) = AccumulatorTuple ((LogPriorAccumulator (),))
263287fast_ldf_accs (:: typeof (getloglikelihood)) = AccumulatorTuple ((LogLikelihoodAccumulator (),))
264288
265- struct LogDensityAt{M<: Model ,F<: Function ,N<: NamedTuple }
289+ struct LogDensityAt{Tlink, M<: Model ,F<: Function ,N<: NamedTuple }
266290 model:: M
267291 getlogdensity:: F
268292 iden_varname_ranges:: N
269293 varname_ranges:: Dict{VarName,RangeAndLinked}
294+
295+ function LogDensityAt {Tlink} (
296+ model:: M ,
297+ getlogdensity:: F ,
298+ iden_varname_ranges:: N ,
299+ varname_ranges:: Dict{VarName,RangeAndLinked} ,
300+ ) where {Tlink,M,F,N}
301+ return new {Tlink,M,F,N} (model, getlogdensity, iden_varname_ranges, varname_ranges)
302+ end
270303end
271- function (f:: LogDensityAt )(params:: AbstractVector{<:Real} )
304+ function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
272305 strategy = InitFromParams (
273- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
306+ VectorWithRanges {Tlink} (f. iden_varname_ranges, f. varname_ranges, params), nothing
274307 )
275308 accs = fast_ldf_accs (f. getlogdensity)
276309 _, vi = DynamicPPL. fast_evaluate!! (f. model, strategy, accs)
277310 return f. getlogdensity (vi)
278311end
279312
280313function LogDensityProblems. logdensity (
281- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
282- )
283- return LogDensityAt (
314+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
315+ ) where {Tlink}
316+ return LogDensityAt {Tlink} (
284317 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
285318 )(
286319 params
287320 )
288321end
289322
290323function LogDensityProblems. logdensity_and_gradient (
291- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
292- )
324+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
325+ ) where {Tlink}
293326 return DI. value_and_gradient (
294- LogDensityAt (
327+ LogDensityAt {Tlink} (
295328 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
296329 ),
297330 ldf. _adprep,
@@ -300,12 +333,14 @@ function LogDensityProblems.logdensity_and_gradient(
300333 )
301334end
302335
303- function LogDensityProblems. capabilities (:: Type{<:LogDensityFunction{M,Nothing}} ) where {M}
336+ function LogDensityProblems. capabilities (
337+ :: Type{<:LogDensityFunction{T,M,Nothing}}
338+ ) where {T,M}
304339 return LogDensityProblems. LogDensityOrder {0} ()
305340end
306341function LogDensityProblems. capabilities (
307- :: Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
308- ) where {M}
342+ :: Type{<:LogDensityFunction{T, M,<:ADTypes.AbstractADType}}
343+ ) where {T, M}
309344 return LogDensityProblems. LogDensityOrder {1} ()
310345end
311346function LogDensityProblems. dimension (ldf:: LogDensityFunction )
0 commit comments