
Commit 6c2c0fa
Make FastLDF the default
1 parent 3cd8d34 commit 6c2c0fa

File tree

11 files changed: +162 -694 lines changed


HISTORY.md

Lines changed: 11 additions & 12 deletions
@@ -4,6 +4,17 @@
 
 ### Breaking changes
 
+#### Fast Log Density Functions
+
+This version reimplements `LogDensityFunction`, providing performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
+Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
+
+For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
+
+As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
+In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
+If you were previously relying on this behaviour, you will need to store a VarInfo separately.
+
 #### Parent and leaf contexts
 
 The `DynamicPPL.NodeTrait` function has been removed.
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
 The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
 This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
 
-### Other changes
-
-#### FastLDF
-
-Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
-Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
-
-Please note that `FastLDF` is currently considered internal and its API may change without warning.
-We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
-
-For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
-
 ## 0.38.9
 
 Remove warning when using Enzyme as the AD backend.
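The breaking change recorded above ("`LogDensityFunction` no longer stores a VarInfo inside it") implies a small migration for downstream code. A minimal sketch, assuming DynamicPPL at or after this commit together with Distributions (the `demo` model is purely illustrative):

```julia
using DynamicPPL, Distributions

# Hypothetical toy model for illustration only.
@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)

# Keep your own reference to the VarInfo; the LogDensityFunction will no
# longer carry one around for you.
varinfo = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, varinfo)

# Only these two field accesses remain supported:
ldf.model   # the original model
ldf.adtype  # nothing, since no `adtype` keyword was passed
```

Code that previously read a `varinfo` field off the `LogDensityFunction` should instead use the separately stored `varinfo` binding.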

docs/src/api.md

Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
 LogDensityFunction
 ```
 
+Internally, this is accomplished using:
+
+```@docs
+OnlyAccsVarInfo
+fast_evaluate!!
+```
+
 ## Condition and decondition
 
 A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 8 additions & 3 deletions
@@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
 # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
 # MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
 # below.
-struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
+struct LogDensityFunctionWrapper{
+    L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
+}
     logdensity::L
+    # This field is used only to reconstruct the VarInfo later on; it's not needed for the
+    # actual log-density evaluation.
+    varinfo::V
 end
 function (lw::LogDensityFunctionWrapper)(x, _)
     return LogDensityProblems.logdensity(lw.logdensity, x)
@@ -101,7 +106,7 @@ function DynamicPPL.marginalize(
     # Construct the marginal log-density model.
     f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
     mld = MarginalLogDensities.MarginalLogDensity(
-        LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
+        LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
     )
     return mld
 end
@@ -190,7 +195,7 @@ function DynamicPPL.VarInfo(
     unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
 )
     # Extract the original VarInfo. Its contents will in general be junk.
-    original_vi = mld.logdensity.logdensity.varinfo
+    original_vi = mld.logdensity.varinfo
     # Extract the stored parameters, which includes the modes for any marginalized
     # parameters
     full_params = MarginalLogDensities.cached_params(mld)
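The reason for threading the VarInfo through the wrapper can be sketched generically. This is a standalone mock of the pattern, not the real DynamicPPL or MarginalLogDensities types: since the log-density object no longer stores the VarInfo, any state needed later must ride along in a separate field.

```julia
# Standalone mock of the wrapper pattern used above: the callable wraps the
# log-density, while an extra field carries state needed only for later
# reconstruction, not for evaluation itself.
struct Wrapper{F,V}
    logdensity::F   # callable used for evaluation
    varinfo::V      # carried along solely so it can be retrieved later
end
(w::Wrapper)(x, _) = w.logdensity(x)

w = Wrapper(x -> -sum(abs2, x) / 2, (:mock_varinfo,))
w([1.0, 2.0], nothing)   # evaluates the wrapped log-density
w.varinfo                # retrieve the stored state later
```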

src/DynamicPPL.jl

Lines changed: 4 additions & 2 deletions
@@ -92,8 +92,10 @@ export AbstractVarInfo,
     getargnames,
     extract_priors,
     values_as_in_model,
-    # LogDensityFunction
+    # LogDensityFunction and fasteval
     LogDensityFunction,
+    fast_evaluate!!,
+    OnlyAccsVarInfo,
     # Leaf contexts
     AbstractContext,
     contextualize,
@@ -198,7 +200,7 @@ include("simple_varinfo.jl")
 include("onlyaccs.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
-include("logdensityfunction.jl")
+include("fasteval.jl")
 include("model_utils.jl")
 include("extract_priors.jl")
 include("values_as_in_model.jl")

src/experimental.jl

Lines changed: 0 additions & 2 deletions
@@ -2,8 +2,6 @@ module Experimental
 
 using DynamicPPL: DynamicPPL
 
-include("fasteval.jl")
-
 # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
 """
     is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)

src/fasteval.jl

Lines changed: 112 additions & 59 deletions
@@ -29,7 +29,61 @@ import DifferentiationInterface as DI
 using Random: Random
 
 """
-    FastLDF(
+    DynamicPPL.fast_evaluate!!(
+        [rng::Random.AbstractRNG,]
+        model::Model,
+        strategy::AbstractInitStrategy,
+        accs::AccumulatorTuple,
+    )
+
+Evaluate a model using parameters obtained via `strategy`, and only computing the results in
+the provided accumulators.
+
+It is assumed that the accumulators passed in have been initialised to appropriate values,
+as this function will not reset them. The default constructors for each accumulator will do
+this for you correctly.
+
+Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
+argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
+in the function name.
+"""
+@inline function fast_evaluate!!(
+    # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
+    # to extra allocations (even for trivial models) and much slower runtime.
+    rng::Random.AbstractRNG,
+    model::Model,
+    strategy::AbstractInitStrategy,
+    accs::AccumulatorTuple,
+)
+    ctx = InitContext(rng, strategy)
+    model = DynamicPPL.setleafcontext(model, ctx)
+    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
+    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
+    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
+    # here.
+    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
+    # it _should_ do, but this is wrong regardless.
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
+    vi = if Threads.nthreads() > 1
+        param_eltype = DynamicPPL.get_param_eltype(strategy)
+        accs = map(accs) do acc
+            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
+        end
+        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
+    else
+        OnlyAccsVarInfo(accs)
+    end
+    return DynamicPPL._evaluate!!(model, vi)
+end
+@inline function fast_evaluate!!(
+    model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
+)
+    # This `@inline` is also mandatory for performance
+    return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
+end
+
+"""
+    DynamicPPL.LogDensityFunction(
         model::Model,
         getlogdensity::Function=getlogjoint_internal,
         varinfo::AbstractVarInfo=VarInfo(model);
@@ -60,26 +114,27 @@ There are several options for `getlogdensity` that are 'supported' out of the bo
    since transforms are only applied to random variables)
 
 !!! note
-    By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of
-    `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created
-    with a linked or unlinked VarInfo. This is done primarily to ease interoperability with
-    MCMC samplers.
+    By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of
+    `LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction`
+    was created with a linked or unlinked VarInfo. This is done primarily to ease
+    interoperability with MCMC samplers.
 
 If you provide one of these functions, a `VarInfo` will be automatically created for you. If
 you provide a different function, you have to manually create a VarInfo and pass it as the
 third argument.
 
 If the `adtype` keyword argument is provided, then this struct will also store the adtype
 along with other information for efficient calculation of the gradient of the log density.
-Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend
-itself to have been loaded (e.g. with `import Backend`).
+Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
+backend itself to have been loaded (e.g. with `import Backend`).
 
 ## Fields
 
-Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from:
+Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart
+from:
 
-- `fastldf.model`: The original model from which this `FastLDF` was constructed.
-- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
+- `ldf.model`: The original model from which this `LogDensityFunction` was constructed.
+- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
   type was provided.
 
 # Extended help
@@ -117,8 +172,9 @@ Traditionally, this problem has been solved by `unflatten`, because that functio
 place values into the VarInfo's metadata alongside the information about ranges and linking.
 That way, when we evaluate with `DefaultContext`, we can read this information out again.
 However, we want to avoid using a metadata. Thus, here, we _extract this information from
-the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we
-store a mapping from VarNames to ranges in that vector, along with link status.
+the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the
+LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with
+link status.
 
 For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
 other VarNames, this is stored in a Dict. The internal data structure used to represent this
@@ -130,13 +186,13 @@ ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quick
 parameter values from the vector.
 
 Note that this assumes that the ranges and link status are static throughout the lifetime of
-the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable
-numbers of parameters, or models which may visit random variables in different orders depending
-on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a
-general limitation of vectorised parameters: the original `unflatten` + `evaluate!!`
-approach also fails with such models.
+the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle
+models which have variable numbers of parameters, or models which may visit random variables
+in different orders depending on stochastic control flow. **Indeed, silent errors may occur
+with such models.** This is a general limitation of vectorised parameters: the original
+`unflatten` + `evaluate!!` approach also fails with such models.
 """
-struct FastLDF{
+struct LogDensityFunction{
     M<:Model,
     AD<:Union{ADTypes.AbstractADType,Nothing},
     F<:Function,
@@ -151,7 +207,7 @@ struct FastLDF{
     _adprep::ADP
     _dim::Int
 
-    function FastLDF(
+    function LogDensityFunction(
         model::Model,
         getlogdensity::Function=getlogjoint_internal,
         varinfo::AbstractVarInfo=VarInfo(model);
@@ -169,7 +225,7 @@ struct FastLDF{
        # Make backend-specific tweaks to the adtype
        adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
        DI.prepare_gradient(
-            FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
+            LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
            adtype,
            x,
        )
@@ -206,76 +262,73 @@ end
 fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
 fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
 
-struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
+struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
     model::M
     getlogdensity::F
     iden_varname_ranges::N
     varname_ranges::Dict{VarName,RangeAndLinked}
 end
-function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
-    ctx = InitContext(
-        Random.default_rng(),
-        InitFromParams(
-            VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
-        ),
+function (f::LogDensityAt)(params::AbstractVector{<:Real})
+    strategy = InitFromParams(
+        VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
     )
-    model = DynamicPPL.setleafcontext(f.model, ctx)
     accs = fast_ldf_accs(f.getlogdensity)
-    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
-    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
-    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
-    # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        accs = map(
-            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
-            accs,
-        )
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    _, vi = DynamicPPL._evaluate!!(model, vi)
+    _, vi = DynamicPPL.fast_evaluate!!(f.model, strategy, accs)
     return f.getlogdensity(vi)
 end
 
-function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
-    return FastLogDensityAt(
-        fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+function LogDensityProblems.logdensity(
+    ldf::LogDensityFunction, params::AbstractVector{<:Real}
+)
+    return LogDensityAt(
+        ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
     )(
        params
     )
 end
 
 function LogDensityProblems.logdensity_and_gradient(
-    fldf::FastLDF, params::AbstractVector{<:Real}
+    ldf::LogDensityFunction, params::AbstractVector{<:Real}
 )
     return DI.value_and_gradient(
-        FastLogDensityAt(
-            fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+        LogDensityAt(
+            ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
        ),
-        fldf._adprep,
-        fldf.adtype,
+        ldf._adprep,
+        ldf.adtype,
        params,
    )
 end
 
-function LogDensityProblems.capabilities(
-    ::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}}
-) where {M}
+function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
     return LogDensityProblems.LogDensityOrder{0}()
 end
 function LogDensityProblems.capabilities(
-    ::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}}
+    ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
 ) where {M}
     return LogDensityProblems.LogDensityOrder{1}()
 end
-function LogDensityProblems.dimension(fldf::FastLDF)
-    return fldf._dim
+function LogDensityProblems.dimension(ldf::LogDensityFunction)
+    return ldf._dim
 end
 
+"""
+    tweak_adtype(
+        adtype::ADTypes.AbstractADType,
+        model::Model,
+        varinfo::AbstractVarInfo,
+    )
+
+Return an 'optimised' form of the adtype. This is useful for doing
+backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating
+the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`).
+The model is passed as a parameter in case the optimisation depends on the
+model.
+
+By default, this just returns the input unchanged.
+"""
+tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype
+
 ######################################################
 # Helper functions to extract ranges and link status #
 ######################################################
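Putting the renamed pieces together, downstream use goes through the LogDensityProblems.jl interface, which now dispatches on `LogDensityFunction` itself rather than the old `FastLDF`. A sketch assuming DynamicPPL at this commit plus Distributions, LogDensityProblems, ADTypes, and ForwardDiff (the model is illustrative):

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function demo()
    m ~ Normal(0, 1)
    s ~ Exponential(1)
end

# Preparing with an adtype requires the corresponding AD backend to be loaded.
ldf = DynamicPPL.LogDensityFunction(demo(); adtype=AutoForwardDiff())

x = [0.0, 0.5]
LogDensityProblems.dimension(ldf)                    # number of parameters
LogDensityProblems.logdensity(ldf, x)                # log joint at x (internal parameterisation)
LogDensityProblems.logdensity_and_gradient(ldf, x)   # value plus gradient via ForwardDiff
```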