@@ -3,6 +3,7 @@ using DynamicPPL:
33 AccumulatorTuple,
44 InitContext,
55 InitFromParams,
6+ AbstractInitStrategy,
67 LogJacobianAccumulator,
78 LogLikelihoodAccumulator,
89 LogPriorAccumulator,
@@ -28,6 +29,49 @@ using LogDensityProblems: LogDensityProblems
2829import DifferentiationInterface as DI
2930using Random: Random
3031
32+ """
33+ DynamicPPL.Experimental.fast_evaluate!!(
34+ [rng::Random.AbstractRNG,]
35+ model::Model,
36+ strategy::AbstractInitStrategy,
37+ accs::AccumulatorTuple, params::AbstractVector{<:Real}
38+ )
39+
40+ Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41+ the provided accumulators.
42+ """
43+ @inline function fast_evaluate!! (
44+ rng:: Random.AbstractRNG ,
45+ model:: Model ,
46+ strategy:: AbstractInitStrategy ,
47+ accs:: AccumulatorTuple ,
48+ )
49+ ctx = InitContext (rng, strategy)
50+ model = DynamicPPL. setleafcontext (model, ctx)
51+ # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
52+ # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
53+ # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
54+ # here.
55+ # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
56+ # it _should_ do, but this is wrong regardless.
57+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
58+ vi = if Threads. nthreads () > 1
59+ param_eltype = DynamicPPL. get_param_eltype (strategy)
60+ accs = map (accs) do acc
61+ DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
62+ end
63+ ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
64+ else
65+ OnlyAccsVarInfo (accs)
66+ end
67+ return DynamicPPL. _evaluate!! (model, vi)
68+ end
69+ @inline function fast_evaluate!! (
70+ model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
71+ )
72+ return fast_evaluate!! (Random. default_rng (), model, strategy, accs)
73+ end
74+
3175"""
3276 FastLDF(
3377 model::Model,
@@ -213,31 +257,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
213257 varname_ranges:: Dict{VarName,RangeAndLinked}
214258end
215259function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
216- ctx = InitContext (
217- Random. default_rng (),
218- InitFromParams (
219- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
220- ),
260+ strategy = InitFromParams (
261+ VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
221262 )
222- model = DynamicPPL. setleafcontext (f. model, ctx)
223263 accs = fast_ldf_accs (f. getlogdensity)
224- # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
225- # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
226- # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
227- # here.
228- # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
229- # it _should_ do, but this is wrong regardless.
230- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
231- vi = if Threads. nthreads () > 1
232- accs = map (
233- acc -> DynamicPPL. convert_eltype (float_type_with_fallback (eltype (params)), acc),
234- accs,
235- )
236- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
237- else
238- OnlyAccsVarInfo (accs)
239- end
240- _, vi = DynamicPPL. _evaluate!! (model, vi)
264+ _, vi = fast_evaluate!! (f. model, strategy, accs)
241265 return f. getlogdensity (vi)
242266end
243267
0 commit comments