Skip to content

Commit 0cc9278

Browse files
committed
Remove SimpleVarInfo
1 parent 612d9ec commit 0cc9278

22 files changed

+195
-1250
lines changed

HISTORY.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ As a result of this change, `LogDensityFunction` no longer stores a VarInfo insi
1515
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
1616
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
1717

18+
Along with this change, DynamicPPL now exposes the `fast_evaluate!!` method which allows you to hook into this 'fast evaluation' pipeline directly.
19+
Please see the documentation for details.
20+
1821
#### Parent and leaf contexts
1922

2023
The `DynamicPPL.NodeTrait` function has been removed.
@@ -28,6 +31,17 @@ Leaf contexts require no changes, apart from a removal of the `NodeTrait` functi
2831
`ConditionContext` and `PrefixContext` are no longer exported.
2932
You should not need to use these directly, please use `AbstractPPL.condition` and `DynamicPPL.prefix` instead.
3033

34+
#### SimpleVarInfo
35+
36+
`SimpleVarInfo` has been removed.
37+
Its main purpose was for evaluating models rapidly.
38+
However, `fast_evaluate!!` provides a cleaner way of doing this.
39+
In particular, if you want to evaluate a model at a given set of parameters, you can do:
40+
41+
```julia
42+
retval, vi = DynamicPPL.fast_evaluate!!(rng, model, InitFromParams(params), accs)
43+
```
44+
3145
#### Miscellaneous
3246

3347
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.

benchmarks/benchmarks.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ chosen_combinations = [
5959
false,
6060
),
6161
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
62-
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
6362
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
64-
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
6563
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
6664
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
6765
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),

benchmarks/src/DynamicPPLBenchmarks.jl

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module DynamicPPLBenchmarks
22

3-
using DynamicPPL: VarInfo, SimpleVarInfo, VarName
3+
using DynamicPPL: VarInfo, VarName
44
using BenchmarkTools: BenchmarkGroup, @benchmarkable
55
using DynamicPPL: DynamicPPL
66
using ADTypes: ADTypes
@@ -58,8 +58,6 @@ Create a benchmark suite for `model` using the selected varinfo type and AD back
5858
Available varinfo choices:
5959
• `:untyped` → uses `DynamicPPL.untyped_varinfo(model)`
6060
• `:typed` → uses `DynamicPPL.typed_varinfo(model)`
61-
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
62-
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
6361
6462
The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
6563
@@ -74,12 +72,6 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
7472
DynamicPPL.untyped_varinfo(rng, model)
7573
elseif varinfo_choice == :typed
7674
DynamicPPL.typed_varinfo(rng, model)
77-
elseif varinfo_choice == :simple_namedtuple
78-
SimpleVarInfo{Float64}(model(rng))
79-
elseif varinfo_choice == :simple_dict
80-
retvals = model(rng)
81-
vns = [VarName{k}() for k in keys(retvals)]
82-
SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals))))
8375
elseif varinfo_choice == :typed_vector
8476
DynamicPPL.typed_vector_varinfo(rng, model)
8577
elseif varinfo_choice == :untyped_vector

benchmarks/src/Models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Models for benchmarking Turing.jl.
33
44
Each model returns a NamedTuple of all the random variables in the model that are not
5-
observed (this is used for constructing SimpleVarInfos).
5+
observed.
66
"""
77
module Models
88

docs/src/api.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,6 @@ set_transformed!!
359359
Base.empty!
360360
```
361361

362-
#### `SimpleVarInfo`
363-
364-
```@docs
365-
SimpleVarInfo
366-
```
367-
368362
### Accumulators
369363

370364
The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.

src/DynamicPPL.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ import Base:
4646
# VarInfo
4747
export AbstractVarInfo,
4848
VarInfo,
49-
SimpleVarInfo,
5049
AbstractAccumulator,
5150
LogLikelihoodAccumulator,
5251
LogPriorAccumulator,
@@ -174,7 +173,7 @@ Abstract supertype for data structures that capture random variables when execut
174173
probabilistic model and accumulate log densities such as the log likelihood or the
175174
log joint probability of the model.
176175
177-
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
176+
See also: [`VarInfo`](@ref).
178177
"""
179178
abstract type AbstractVarInfo <: AbstractModelTrace end
180179

@@ -196,7 +195,6 @@ include("default_accumulators.jl")
196195
include("abstract_varinfo.jl")
197196
include("threadsafe.jl")
198197
include("varinfo.jl")
199-
include("simple_varinfo.jl")
200198
include("onlyaccs.jl")
201199
include("compiler.jl")
202200
include("pointwise_logdensities.jl")

src/abstract_varinfo.jl

Lines changed: 4 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -502,52 +502,6 @@ If no `Type` is provided, return values as stored in `varinfo`.
502502
503503
# Examples
504504
505-
`SimpleVarInfo` with `NamedTuple`:
506-
507-
```jldoctest
508-
julia> data = (x = 1.0, m = [2.0]);
509-
510-
julia> values_as(SimpleVarInfo(data))
511-
(x = 1.0, m = [2.0])
512-
513-
julia> values_as(SimpleVarInfo(data), NamedTuple)
514-
(x = 1.0, m = [2.0])
515-
516-
julia> values_as(SimpleVarInfo(data), OrderedDict)
517-
OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries:
518-
x => 1.0
519-
m => [2.0]
520-
521-
julia> values_as(SimpleVarInfo(data), Vector)
522-
2-element Vector{Float64}:
523-
1.0
524-
2.0
525-
```
526-
527-
`SimpleVarInfo` with `OrderedDict`:
528-
529-
```jldoctest
530-
julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]);
531-
532-
julia> values_as(SimpleVarInfo(data))
533-
OrderedDict{Any, Any} with 2 entries:
534-
x => 1.0
535-
m => [2.0]
536-
537-
julia> values_as(SimpleVarInfo(data), NamedTuple)
538-
(x = 1.0, m = [2.0])
539-
540-
julia> values_as(SimpleVarInfo(data), OrderedDict)
541-
OrderedDict{Any, Any} with 2 entries:
542-
x => 1.0
543-
m => [2.0]
544-
545-
julia> values_as(SimpleVarInfo(data), Vector)
546-
2-element Vector{Float64}:
547-
1.0
548-
2.0
549-
```
550-
551505
`VarInfo` with `NamedTuple` of `Metadata`:
552506
553507
```jldoctest
@@ -828,8 +782,8 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
828782
return link!!(default_transformation(model, vi), vi, vns, model)
829783
end
830784
function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
831-
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
832-
# has a dedicated implementation
785+
# Note that VarInfo has a dedicated implementation so this is only a generic
786+
# fallback (previously used for SimpleVarInfo)
833787
model = setleafcontext(model, DynamicTransformationContext{false}())
834788
vi = last(evaluate!!(model, vi))
835789
return set_transformed!!(vi, t)
@@ -890,8 +844,8 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model)
890844
return invlink!!(default_transformation(model, vi), vi, vns, model)
891845
end
892846
function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
893-
# Note that in practice this method is only called for SimpleVarInfo, because VarInfo
894-
# has a dedicated implementation
847+
# Note that VarInfo has a dedicated implementation so this is only a generic
848+
# fallback (previously used for SimpleVarInfo)
895849
model = setleafcontext(model, DynamicTransformationContext{true}())
896850
vi = last(evaluate!!(model, vi))
897851
return set_transformed!!(vi, NoTransformation())
@@ -946,47 +900,6 @@ This will be called prior to `model` evaluation, allowing one to perform a singl
946900
basis as is done with [`DynamicTransformation`](@ref).
947901
948902
See also: [`StaticTransformation`](@ref), [`DynamicTransformation`](@ref).
949-
950-
# Examples
951-
```julia-repl
952-
julia> using DynamicPPL, Distributions, Bijectors
953-
954-
julia> @model demo() = x ~ Normal()
955-
demo (generic function with 2 methods)
956-
957-
julia> # By subtyping `Transform`, we inherit the `(inv)link!!`.
958-
struct MyBijector <: Bijectors.Transform end
959-
960-
julia> # Define some dummy `inverse` which will be used in the `link!!` call.
961-
Bijectors.inverse(f::MyBijector) = identity
962-
963-
julia> # We need to define `with_logabsdet_jacobian` for `MyBijector`
964-
# (`identity` already has `with_logabsdet_jacobian` defined)
965-
function Bijectors.with_logabsdet_jacobian(::MyBijector, x)
966-
# Just using a large number of the logabsdet-jacobian term
967-
# for demonstration purposes.
968-
return (x, 1000)
969-
end
970-
971-
julia> # Change the `default_transformation` for our model to be a
972-
# `StaticTransformation` using `MyBijector`.
973-
function DynamicPPL.default_transformation(::Model{typeof(demo)})
974-
return DynamicPPL.StaticTransformation(MyBijector())
975-
end
976-
977-
julia> model = demo();
978-
979-
julia> vi = SimpleVarInfo(x=1.0)
980-
SimpleVarInfo((x = 1.0,), 0.0)
981-
982-
julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity`
983-
vi_linked = link!!(vi, model)
984-
Transformed SimpleVarInfo((x = 1.0,), 0.0)
985-
986-
julia> # Now performs a single `invlink!!` before model evaluation.
987-
logjoint(model, vi_linked)
988-
-1001.4189385332047
989-
```
990903
"""
991904
function maybe_invlink_before_eval!!(vi::AbstractVarInfo, model::Model)
992905
return maybe_invlink_before_eval!!(transformation(vi), vi, model)

src/contexts/transformation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ constrained space if `isinverse` or unconstrained if `!isinverse`.
77
Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the
88
`DynamicTransformationContext` methods with more efficient implementations.
99
`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know
10-
how to do the transformation, used by e.g. `SimpleVarInfo`.
10+
how to do the transformation.
1111
"""
1212
struct DynamicTransformationContext{isinverse} <: AbstractContext end
1313

src/fasteval.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,11 +334,6 @@ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtyp
334334
# Helper functions to extract ranges and link status #
335335
######################################################
336336

337-
# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The
338-
# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges
339-
# and link status. So there is no motivation to use SimpleVarInfo inside a
340-
# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue
341-
# that there is no purpose in supporting untyped VarInfo either.
342337
"""
343338
get_ranges_and_linked(varinfo::VarInfo)
344339

src/model.jl

Lines changed: 116 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f)
10621062
Generate a sample of type `T` from the prior distribution of the `model`.
10631063
"""
10641064
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
1065-
x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())))
1066-
return values_as(x, T)
1065+
# TODO(penelopeysm): This can be done with an accumulator instead. For
1066+
# T = Dict, ValuesAsInModelAcc can already do it. For T = NamedTuple we
1067+
# would just need a similar accumulator that collects into a NamedTuple
1068+
# rather than a Dict.
1069+
return values_as(VarInfo(rng, model), T)
10671070
end
10681071

10691072
# Default RNG and type
@@ -1155,12 +1158,115 @@ julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
11551158
```
11561159
"""
11571160
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
1158-
vi = DynamicPPL.setaccs!!(VarInfo(), ())
1159-
# Note: we can't use `fix(model, parameters)` because
1160-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
1161-
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
1162-
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
1163-
new_model = setleafcontext(model, ctx)
1164-
# We can't use new_model() because that overwrites it with an InitContext of its own.
1165-
return first(evaluate!!(new_model, vi))
1161+
accs = AccumulatorTuple()
1162+
retval, _ = DynamicPPL.fast_evaluate!!(model, InitFromParams(parameters, nothing), accs)
1163+
return retval
1164+
end
1165+
1166+
"""
1167+
logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
1168+
1169+
Return the log joint probability of variables `θ` for the probabilistic `model`.
1170+
1171+
See [`logprior`](@ref) and [`loglikelihood`](@ref).
1172+
1173+
# Examples
1174+
```jldoctest; setup=:(using Distributions)
1175+
julia> @model function demo(x)
1176+
m ~ Normal()
1177+
for i in eachindex(x)
1178+
x[i] ~ Normal(m, 1.0)
1179+
end
1180+
end
1181+
demo (generic function with 2 methods)
1182+
1183+
julia> # Using a `NamedTuple`.
1184+
logjoint(demo([1.0]), (m = 100.0, ))
1185+
-9902.33787706641
1186+
1187+
julia> # Using a `OrderedDict`.
1188+
logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0))
1189+
-9902.33787706641
1190+
1191+
julia> # Truth.
1192+
logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0)
1193+
-9902.33787706641
1194+
```
1195+
"""
1196+
function logjoint(model::Model, θ::Union{NamedTuple,AbstractDict})
1197+
accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
1198+
_, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs)
1199+
return getlogjoint(vi)
1200+
end
1201+
1202+
"""
1203+
logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
1204+
1205+
Return the log prior probability of variables `θ` for the probabilistic `model`.
1206+
1207+
See also [`logjoint`](@ref) and [`loglikelihood`](@ref).
1208+
1209+
# Examples
1210+
```jldoctest; setup=:(using Distributions)
1211+
julia> @model function demo(x)
1212+
m ~ Normal()
1213+
for i in eachindex(x)
1214+
x[i] ~ Normal(m, 1.0)
1215+
end
1216+
end
1217+
demo (generic function with 2 methods)
1218+
1219+
julia> # Using a `NamedTuple`.
1220+
logprior(demo([1.0]), (m = 100.0, ))
1221+
-5000.918938533205
1222+
1223+
julia> # Using a `OrderedDict`.
1224+
logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0))
1225+
-5000.918938533205
1226+
1227+
julia> # Truth.
1228+
logpdf(Normal(), 100.0)
1229+
-5000.918938533205
1230+
```
1231+
"""
1232+
function logprior(model::Model, θ::Union{NamedTuple,AbstractDict})
1233+
accs = AccumulatorTuple((LogPriorAccumulator(),))
1234+
_, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs)
1235+
return getlogprior(vi)
1236+
end
1237+
1238+
"""
1239+
loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict})
1240+
1241+
Return the log likelihood of variables `θ` for the probabilistic `model`.
1242+
1243+
See also [`logjoint`](@ref) and [`logprior`](@ref).
1244+
1245+
# Examples
1246+
```jldoctest; setup=:(using Distributions)
1247+
julia> @model function demo(x)
1248+
m ~ Normal()
1249+
for i in eachindex(x)
1250+
x[i] ~ Normal(m, 1.0)
1251+
end
1252+
end
1253+
demo (generic function with 2 methods)
1254+
1255+
julia> # Using a `NamedTuple`.
1256+
loglikelihood(demo([1.0]), (m = 100.0, ))
1257+
-4901.418938533205
1258+
1259+
julia> # Using a `OrderedDict`.
1260+
loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0))
1261+
-4901.418938533205
1262+
1263+
julia> # Truth.
1264+
logpdf(Normal(100.0, 1.0), 1.0)
1265+
-4901.418938533205
1266+
```
1267+
"""
1268+
function Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict})
1269+
accs = AccumulatorTuple((LogLikelihoodAccumulator(),))
1270+
_, vi = DynamicPPL.fast_evaluate!!(model, InitFromParams(θ, nothing), accs)
1271+
return getloglikelihood(vi)
11661272
end

0 commit comments

Comments
 (0)