Skip to content

Commit 95b84ac

Browse files
committed
Make FastLDF the default
1 parent 2fad97b commit 95b84ac

File tree

13 files changed

+114
-645
lines changed

13 files changed

+114
-645
lines changed

HISTORY.md

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@
44

55
### Breaking changes
66

7+
#### Fast Log Density Functions
8+
9+
This version provides a reimplementation of `LogDensityFunction` that yields performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
10+
Exact speedups depend on the model size: larger models see smaller speedups because the bulk of the work is done in calls to `logpdf`.
11+
12+
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
13+
14+
As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
15+
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
16+
If you were previously relying on the VarInfo stored inside a `LogDensityFunction` (e.g. by accessing `ldf.varinfo`), you will need to store a VarInfo separately.
17+
718
#### Parent and leaf contexts
819

920
The `DynamicPPL.NodeTrait` function has been removed.
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
2435
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
2536
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
2637

27-
### Other changes
28-
29-
#### FastLDF
30-
31-
Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
32-
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
33-
34-
Please note that `FastLDF` is currently considered internal and its API may change without warning.
35-
We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
36-
37-
For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
38-
3938
## 0.38.9
4039

4140
Remove warning when using Enzyme as the AD backend.

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
6666
LogDensityFunction
6767
```
6868

69+
Internally, this is accomplished using:
70+
71+
```@docs
72+
OnlyAccsVarInfo
73+
fast_evaluate!!
74+
```
75+
6976
## Condition and decondition
7077

7178
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
66
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
77
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
88
# below.
9-
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
9+
struct LogDensityFunctionWrapper{
10+
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
11+
}
1012
logdensity::L
13+
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
14+
# actual log-density evaluation.
15+
varinfo::V
1116
end
1217
function (lw::LogDensityFunctionWrapper)(x, _)
1318
return LogDensityProblems.logdensity(lw.logdensity, x)
@@ -101,7 +106,7 @@ function DynamicPPL.marginalize(
101106
# Construct the marginal log-density model.
102107
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
103108
mld = MarginalLogDensities.MarginalLogDensity(
104-
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
109+
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
105110
)
106111
return mld
107112
end
@@ -190,7 +195,7 @@ function DynamicPPL.VarInfo(
190195
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
191196
)
192197
# Extract the original VarInfo. Its contents will in general be junk.
193-
original_vi = mld.logdensity.logdensity.varinfo
198+
original_vi = mld.logdensity.varinfo
194199
# Extract the stored parameters, which includes the modes for any marginalized
195200
# parameters
196201
full_params = MarginalLogDensities.cached_params(mld)

src/DynamicPPL.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@ export AbstractVarInfo,
9292
getargnames,
9393
extract_priors,
9494
values_as_in_model,
95-
# LogDensityFunction
95+
# LogDensityFunction and fasteval
9696
LogDensityFunction,
97+
fast_evaluate!!,
98+
OnlyAccsVarInfo,
9799
# Leaf contexts
98100
AbstractContext,
99101
contextualize,
@@ -198,7 +200,7 @@ include("simple_varinfo.jl")
198200
include("onlyaccs.jl")
199201
include("compiler.jl")
200202
include("pointwise_logdensities.jl")
201-
include("logdensityfunction.jl")
203+
include("fasteval.jl")
202204
include("model_utils.jl")
203205
include("extract_priors.jl")
204206
include("values_as_in_model.jl")

src/chains.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
"""
138138
ParamsWithStats(
139139
param_vector::AbstractVector,
140-
ldf::DynamicPPL.Experimental.FastLDF,
140+
ldf::DynamicPPL.LogDensityFunction,
141141
stats::NamedTuple=NamedTuple();
142142
include_colon_eq::Bool=true,
143143
include_log_probs::Bool=true,
@@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
156156
"""
157157
function ParamsWithStats(
158158
param_vector::AbstractVector,
159-
ldf::DynamicPPL.Experimental.FastLDF,
159+
ldf::DynamicPPL.LogDensityFunction,
160160
stats::NamedTuple=NamedTuple();
161161
include_colon_eq::Bool=true,
162162
include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
174174
else
175175
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
176176
end
177-
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
178-
ldf.model, strategy, AccumulatorTuple(accs)
179-
)
177+
_, vi = DynamicPPL.fast_evaluate!!(ldf.model, strategy, AccumulatorTuple(accs))
180178
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
181179
if include_log_probs
182180
stats = merge(

src/experimental.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ module Experimental
22

33
using DynamicPPL: DynamicPPL
44

5-
include("fasteval.jl")
6-
75
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
86
"""
97
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)

src/fasteval.jl

Lines changed: 56 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import DifferentiationInterface as DI
3030
using Random: Random
3131

3232
"""
33-
DynamicPPL.Experimental.fast_evaluate!!(
33+
DynamicPPL.fast_evaluate!!(
3434
[rng::Random.AbstractRNG,]
3535
model::Model,
3636
strategy::AbstractInitStrategy,
@@ -84,7 +84,7 @@ end
8484
end
8585

8686
"""
87-
FastLDF(
87+
DynamicPPL.LogDensityFunction(
8888
model::Model,
8989
getlogdensity::Function=getlogjoint_internal,
9090
varinfo::AbstractVarInfo=VarInfo(model);
@@ -115,26 +115,27 @@ There are several options for `getlogdensity` that are 'supported' out of the bo
115115
since transforms are only applied to random variables)
116116
117117
!!! note
118-
By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of
119-
`LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created
120-
with a linked or unlinked VarInfo. This is done primarily to ease interoperability with
121-
MCMC samplers.
118+
By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the result of
119+
`LogDensityProblems.logdensity(f, x)` will depend on whether the `LogDensityFunction`
120+
was created with a linked or unlinked VarInfo. This is done primarily to ease
121+
interoperability with MCMC samplers.
122122
123123
If you provide one of these functions, a `VarInfo` will be automatically created for you. If
124124
you provide a different function, you have to manually create a VarInfo and pass it as the
125125
third argument.
126126
127127
If the `adtype` keyword argument is provided, then this struct will also store the adtype
128128
along with other information for efficient calculation of the gradient of the log density.
129-
Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend
130-
itself to have been loaded (e.g. with `import Backend`).
129+
Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
130+
backend itself to have been loaded (e.g. with `import Backend`).
131131
132132
## Fields
133133
134-
Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from:
134+
Note that it is undefined behaviour to access any of a `LogDensityFunction`'s fields, apart
135+
from:
135136
136-
- `fastldf.model`: The original model from which this `FastLDF` was constructed.
137-
- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
137+
- `ldf.model`: The original model from which this `LogDensityFunction` was constructed.
138+
- `ldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
138139
type was provided.
139140
140141
# Extended help
@@ -172,8 +173,9 @@ Traditionally, this problem has been solved by `unflatten`, because that functio
172173
place values into the VarInfo's metadata alongside the information about ranges and linking.
173174
That way, when we evaluate with `DefaultContext`, we can read this information out again.
174175
However, we want to avoid using a metadata. Thus, here, we _extract this information from
175-
the VarInfo_ a single time when constructing a `FastLDF` object. Inside the FastLDF, we
176-
store a mapping from VarNames to ranges in that vector, along with link status.
176+
the VarInfo_ a single time when constructing a `LogDensityFunction` object. Inside the
177+
LogDensityFunction, we store a mapping from VarNames to ranges in that vector, along with
178+
link status.
177179
178180
For VarNames with identity optics, this is stored in a NamedTuple for efficiency. For all
179181
other VarNames, this is stored in a Dict. The internal data structure used to represent this
@@ -185,13 +187,13 @@ ranges to create an `InitFromParams{VectorWithRanges}`, which lets us very quick
185187
parameter values from the vector.
186188
187189
Note that this assumes that the ranges and link status are static throughout the lifetime of
188-
the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable
189-
numbers of parameters, or models which may visit random variables in different orders depending
190-
on stochastic control flow. **Indeed, silent errors may occur with such models.** This is a
191-
general limitation of vectorised parameters: the original `unflatten` + `evaluate!!`
192-
approach also fails with such models.
190+
the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot handle
191+
models which have variable numbers of parameters, or models which may visit random variables
192+
in different orders depending on stochastic control flow. **Indeed, silent errors may occur
193+
with such models.** This is a general limitation of vectorised parameters: the original
194+
`unflatten` + `evaluate!!` approach also fails with such models.
193195
"""
194-
struct FastLDF{
196+
struct LogDensityFunction{
195197
M<:Model,
196198
AD<:Union{ADTypes.AbstractADType,Nothing},
197199
F<:Function,
@@ -206,7 +208,7 @@ struct FastLDF{
206208
_adprep::ADP
207209
_dim::Int
208210

209-
function FastLDF(
211+
function LogDensityFunction(
210212
model::Model,
211213
getlogdensity::Function=getlogjoint_internal,
212214
varinfo::AbstractVarInfo=VarInfo(model);
@@ -224,7 +226,7 @@ struct FastLDF{
224226
# Make backend-specific tweaks to the adtype
225227
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
226228
DI.prepare_gradient(
227-
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
229+
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
228230
adtype,
229231
x,
230232
)
@@ -261,13 +263,13 @@ end
261263
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
262264
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
263265

264-
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
266+
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
265267
model::M
266268
getlogdensity::F
267269
iden_varname_ranges::N
268270
varname_ranges::Dict{VarName,RangeAndLinked}
269271
end
270-
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
272+
function (f::LogDensityAt)(params::AbstractVector{<:Real})
271273
strategy = InitFromParams(
272274
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
273275
)
@@ -276,41 +278,58 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
276278
return f.getlogdensity(vi)
277279
end
278280

279-
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
280-
return FastLogDensityAt(
281-
fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
281+
function LogDensityProblems.logdensity(
282+
ldf::LogDensityFunction, params::AbstractVector{<:Real}
283+
)
284+
return LogDensityAt(
285+
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
282286
)(
283287
params
284288
)
285289
end
286290

287291
function LogDensityProblems.logdensity_and_gradient(
288-
fldf::FastLDF, params::AbstractVector{<:Real}
292+
ldf::LogDensityFunction, params::AbstractVector{<:Real}
289293
)
290294
return DI.value_and_gradient(
291-
FastLogDensityAt(
292-
fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
295+
LogDensityAt(
296+
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
293297
),
294-
fldf._adprep,
295-
fldf.adtype,
298+
ldf._adprep,
299+
ldf.adtype,
296300
params,
297301
)
298302
end
299303

300-
function LogDensityProblems.capabilities(
301-
::Type{<:DynamicPPL.Experimental.FastLDF{M,Nothing}}
302-
) where {M}
304+
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
303305
return LogDensityProblems.LogDensityOrder{0}()
304306
end
305307
function LogDensityProblems.capabilities(
306-
::Type{<:DynamicPPL.Experimental.FastLDF{M,<:ADTypes.AbstractADType}}
308+
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
307309
) where {M}
308310
return LogDensityProblems.LogDensityOrder{1}()
309311
end
310-
function LogDensityProblems.dimension(fldf::FastLDF)
311-
return fldf._dim
312+
function LogDensityProblems.dimension(ldf::LogDensityFunction)
313+
return ldf._dim
312314
end
313315

316+
"""
317+
tweak_adtype(
318+
adtype::ADTypes.AbstractADType,
319+
model::Model,
320+
varinfo::AbstractVarInfo,
321+
)
322+
323+
Return an 'optimised' form of the adtype. This is useful for doing
324+
backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating
325+
the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`).
326+
The model is passed as a parameter in case the optimisation depends on the
327+
model.
328+
329+
By default, this just returns the input unchanged.
330+
"""
331+
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype
332+
314333
######################################################
315334
# Helper functions to extract ranges and link status #
316335
######################################################

0 commit comments

Comments
 (0)