From 709dc9ecff1ed7cde441447ca6a6108f182a219c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 20:47:54 +0100 Subject: [PATCH 01/16] point to unmerged AbstractPPL branch --- test/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index afecba1c4..c9c1a6478 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -50,3 +50,6 @@ ReverseDiff = "1" StableRNGs = "1" Zygote = "0.6, 0.7" julia = "1.10" + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/hasgetvalue"} From a19c9a6fec272877317734a4a454a3fbb874698b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 20:53:45 +0100 Subject: [PATCH 02/16] Remove code that was moved to AbstractPPL --- Project.toml | 3 + src/utils.jl | 193 --------------------------------------------------- 2 files changed, 3 insertions(+), 193 deletions(-) diff --git a/Project.toml b/Project.toml index c23845b8c..5d9059be9 100644 --- a/Project.toml +++ b/Project.toml @@ -75,3 +75,6 @@ Requires = "1" Statistics = "1" Test = "1.6" julia = "1.10.8" + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "py/hasgetvalue"} diff --git a/src/utils.jl b/src/utils.jl index 0f4d98b11..af2891a2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -751,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector) return D(zip(keys(original), unflatten(collect(values(original)), x))) end -# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. -""" - getvalue(vals, vn::VarName) - -Return the value(s) in `vals` represented by `vn`. - -Note that this method is different from `getindex`. See examples below. 
- -# Examples - -For `NamedTuple`: - -```jldoctest -julia> vals = (x = [1.0],); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] -``` - -For `AbstractDict`: - -```jldoctest -julia> vals = Dict(@varname(x) => [1.0]); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> DynamicPPL.getvalue(vals, @varname(x[2][1])) -ERROR: KeyError: key x[2][1] not found -[...] -``` -""" -getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) -getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) - -""" - hasvalue(vals, vn::VarName) - -Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref). 
- -# Examples -With `x` as a `NamedTuple`: - -```jldoctest -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2])) -false -``` - -With `x` as a `AbstractDict`: - -```jldoctest -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) -false -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][2])) -false - -julia> DynamicPPL.hasvalue(vals, @varname(x[2][1])) -false -``` -""" -function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} - # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the optic can view into `nt`. - return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) -end - -# For `dictlike` we need to check wether `vn` is "immediately" present, or -# if some ancestor of `vn` is present in `dictlike`. -function hasvalue(vals::AbstractDict, vn::VarName) - # First we check if `vn` is present as is. - haskey(vals, vn) && return true - - # If `vn` is not present, we check any parent-varnames by attempting - # to split the optic into the key / `parent` and the extraction optic / `child`. 
- # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(vals, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # Return early if no such split could be found. - issuccess || return false - - # At this point we just need to check that we `canview` the value. - value = vals[VarName{getsym(vn)}(keyoptic)] - - return canview(child, value) -end - -""" - nested_getindex(values::AbstractDict, vn::VarName) - -Return value corresponding to `vn` in `values` by also looking -in the the actual values of the dict. -""" -function nested_getindex(values::AbstractDict, vn::VarName) - maybeval = get(values, vn, nothing) - if maybeval !== nothing - return maybeval - end - - # Split the optic into the key / `parent` and the extraction optic / `child`. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(values, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end - - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? 
- value = values[VarName{getsym(vn)}(keyoptic)] - return child(value) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) From f00879bec599ebf0dea9c95428730a0d926d83b3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 22:13:22 +0100 Subject: [PATCH 03/16] Remove Dictionaries with Any key type --- src/DynamicPPL.jl | 2 +- src/model.jl | 12 ++++++++---- src/simple_varinfo.jl | 8 ++++---- src/test_utils/varinfo.jl | 2 +- src/values_as_in_model.jl | 4 ++-- src/varnamedvector.jl | 2 +- test/runtests.jl | 3 +++ test/simple_varinfo.jl | 2 +- test/varinfo.jl | 12 +++++------- 9 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 69e489ce6..4c2f0bd00 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -23,7 +23,7 @@ using DocStringExtensions using Random: Random # For extending -import AbstractPPL: predict +import AbstractPPL: predict, hasvalue, getvalue # TODO: Remove these when it's possible. import Bijectors: link, invlink diff --git a/src/model.jl b/src/model.jl index 93e77eaec..72a7ac294 100644 --- a/src/model.jl +++ b/src/model.jl @@ -981,7 +981,11 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. 
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(evaluate_and_sample!!(rng, model, SimpleVarInfo{Float64}(OrderedDict()))) + x = last( + evaluate_and_sample!!( + rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) + ), + ) return values_as(x, T) end @@ -1028,7 +1032,7 @@ julia> logjoint(demo_model([1., 2.]), chain); function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1082,7 +1086,7 @@ julia> logprior(demo_model([1., 2.]), chain); function logprior(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1136,7 +1140,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ddc3275ae..019f9cbc8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -62,7 +62,7 @@ ERROR: type NamedTuple has no field x [...] 
julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict())); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -107,7 +107,7 @@ julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true)); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 @@ -206,7 +206,7 @@ end function SimpleVarInfo(values) return SimpleVarInfo{LogProbType}(values) end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict}) +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) return if isempty(values) # Can't infer from values, so we just use default. 
SimpleVarInfo{LogProbType}(values) @@ -258,7 +258,7 @@ function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} end function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict()) + varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) return last(evaluate_and_sample!!(model, varinfo)) end diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 542fc17fc..26e2aa7ca 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -34,7 +34,7 @@ function setup_varinfos( # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict()) + svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) varinfos = map(( diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 4922ddbb0..a10141dec 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -12,12 +12,12 @@ $(TYPEDFIELDS) """ struct ValuesAsInModelAccumulator <: AbstractAccumulator "values that are extracted from the model" - values::OrderedDict + values::OrderedDict{<:VarName} "whether to extract variables on the LHS of :=" include_colon_eq::Bool end function ValuesAsInModelAccumulator(include_colon_eq) - return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq) + return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) end accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 965db96d5..5de0874c9 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} end # See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how -# they differ from `haskey` and `getindex`. They can be found in src/utils.jl. +# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl. 
# TODO(mhauru) This is tricky to implement in the general case, and the below implementation # only covers some simple cases. It's probably sufficient in most situations though. diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..e6eb42673 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,9 @@ using LinearAlgebra # Diagonal using JET: JET +# need to call this to get the AbstractPPL I think +Pkg.update() + using Combinatorics: combinations using OrderedCollections: OrderedSet diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e300c651e..2f080e66b 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,7 +90,7 @@ DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), + SimpleVarInfo(Dict{VarName,Any}()), SimpleVarInfo(values_constrained), SimpleVarInfo(DynamicPPL.VarNamedVector()), DynamicPPL.typed_varinfo(model), diff --git a/test/varinfo.jl b/test/varinfo.jl index 75868eb66..e1cb56135 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -110,7 +110,7 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict())) + test_base(SimpleVarInfo(Dict{VarName,Any}())) test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @@ -604,8 +604,7 @@ end ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. 
- vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -750,11 +749,10 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the severely inconcrete `SimpleVarInfo` types, since checking for type + # Skip the inconcrete `SimpleVarInfo` types, since checking for type # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{OrderedDict{Any,Any}} || - varinfo isa - DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{OrderedDict{Any,Any}}} + if varinfo isa SimpleVarInfo{<:AbstractDict} || + varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} continue end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) From af7c6fc241e4170eba8bfc7d907e8c01f5258b6b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 22:46:56 +0100 Subject: [PATCH 04/16] Fix bad merge conflict resolution --- test/varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index e1cb56135..7f4c2daad 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -603,8 +603,9 @@ end @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) + # Sample in unconstrained space. 
+ vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) From 21cf5687e5b109191b890baa13840432650a43fa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:02:18 +0100 Subject: [PATCH 05/16] Fix doctests --- src/simple_varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 019f9cbc8..0dc2415ca 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -70,11 +70,11 @@ julia> # (✓) Sort of fast, but only possible at runtime. julia> # In addtion, we can only access varnames as they appear in the model! vi[@varname(x)] -ERROR: KeyError: key x not found +ERROR: getvalue: x was not found in the values provided [...] julia> vi[@varname(x[1:2])] -ERROR: KeyError: key x[1:2] not found +ERROR: getvalue: x[1:2] was not found in the values provided [...] ``` @@ -177,11 +177,11 @@ julia> svi_dict[@varname(m.a[1])] 1.0 julia> svi_dict[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: getvalue: m.a[2] was not found in the values provided [...] julia> svi_dict[@varname(m.b)] -ERROR: type NamedTuple has no field b +ERROR: getvalue: m.b was not found in the values provided [...] 
``` """ From 54446c17a50c948a63877de64170a64307a9aa92 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:20:10 +0100 Subject: [PATCH 06/16] Implement InitContext --- src/DynamicPPL.jl | 7 ++ src/contexts/init.jl | 180 +++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 33 ++++++++ test/contexts.jl | 12 ++- 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/contexts/init.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4c2f0bd00..6b20899d9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -109,6 +109,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + PriorInit, + UniformInit, + ParamsInit, # Pseudo distributions NamedDist, NoDist, @@ -175,6 +181,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..580b1a666 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,180 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Values must be unlinked" + The values returned by `init` are always in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, vn, dist, u::UniformInit)` will in general return values that + are outside the range [u.lower, u.upper]. 
+""" +function init end + +""" + PriorInit() + +Obtain new values by sampling from the prior distribution. +""" +struct PriorInit <: AbstractInitStrategy end +init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) + +""" + UniformInit() + UniformInit(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, and then sampling a value uniformly between `lower` and +`upper`. + +If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's +default initialisation strategy. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + UniformInit() = UniformInit(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = rand(rng, Uniform(u.lower, u.upper), sz) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) + ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + +Obtain new values by extracting them from the given dictionary or NamedTuple. +The parameter `default` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. The default +for `default` is `PriorInit()`. + +!!! note + These values must be provided in the space of the untransformed distribution. 
+""" +struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy + params::P + default::S + function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + return new{typeof(params),typeof(default)}(params, default) + end + ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) + function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + return ParamsInit(to_varname_dict(params), default) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) + # TODO(penelopeysm): We should do a check to make sure that all of the + # parameters in `p.params` were actually used, and either warn or error if + # they aren't. This is non-trivial (we need to use something like + # varname_leaves), so I'm going to defer it to a later PR. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + init(rng, vn, dist, p.default) + else + # TODO(penelopeysm): We could also check that the type of x matches + # the dist? + x + end + else + init(rng, vn, dist, p.default) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=PriorInit()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. 
+""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=PriorInit()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo already had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + to_linked_internal_transform(vi, vn, dist) + else + to_internal_transform(vi, vn, dist) + end + # TODO(penelopeysm): We would really like to do: + # y, logjac = with_logabsdet_jacobian(f, x) + # Unfortunately, `to_{linked_}internal_transform` returns a function that + # always converts x to a vector, i.e., if dist is univariate, f(x) will be + # a vector of length 1. It would be nice if we could unify these. + y = f(x) + logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. `settrans!!` may return a new VarInfo, so rebind `vi`. 
+ insert_transformed_value && (vi = settrans!!(vi, true, vn)) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index 72a7ac294..f14744bf2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -854,6 +854,39 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=PriorInit()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to PriorInit(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. 
+""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=PriorInit(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..be976aad4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -431,4 +431,14 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + @testset "PriorInit" begin end + + @testset "UniformInit" begin end + + @testset "ParamsInit" begin end + + @testset "rng is respected (at least with PriorInit" begin end + end end From 1ef1a9285f62b9ca64b1ed8739516874aa1c8701 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:25:10 +0100 Subject: [PATCH 07/16] Fix loading order of modules; move `prefix(::Model)` to model.jl --- src/DynamicPPL.jl | 4 ++-- src/contexts.jl | 35 ----------------------------------- src/model.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6b20899d9..bb6af996e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,12 +176,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") 
include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") -include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/model.jl b/src/model.jl index f14744bf2..6be4eb383 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. 
""" fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) From a90d95e8dadd76604a1ffe18155265ed5b8a2239 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:21:53 +0100 Subject: [PATCH 08/16] Add tests for InitContext behaviour --- src/contexts/init.jl | 12 +-- test/contexts.jl | 183 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 179 insertions(+), 16 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 580b1a666..6ff276d21 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -147,17 +147,11 @@ function tilde_assume( # are linked. insert_transformed_value = in_varinfo ? 
istrans(vi, vn) : istrans(vi) f = if insert_transformed_value - to_linked_internal_transform(vi, vn, dist) + link_transform(dist) else - to_internal_transform(vi, vn, dist) + identity end - # TODO(penelopeysm): We would really like to do: - # y, logjac = with_logabsdet_jacobian(f, x) - # Unfortunately, `to_{linked_}internal_transform` returns a function that - # always converts x to a vector, i.e., if dist is univariate, f(x) will be - # a vector of length 1. It would be nice if we could unify these. - y = f(x) - logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + y, logjac = with_logabsdet_jacobian(f, x) # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo diff --git a/test/contexts.jl b/test/contexts.jl index be976aad4..5768757bb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. 
@test hasconditioned_nested(context, vn) @@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - @testset "PriorInit" begin end + empty_varinfos = [ + VarInfo(), + DynamicPPL.typed_varinfo(VarInfo()), + VarInfo(DynamicPPL.VarNamedVector()), + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + SimpleVarInfo(), + SimpleVarInfo(Dict{VarName,Any}()), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + for empty_vi in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. 
+ model = test_init_model() + for empty_vi in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + for empty_vi in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end - @testset "UniformInit" begin end + @testset "PriorInit" begin + test_generating_new_values(PriorInit()) + test_replacing_values(PriorInit()) + test_rng_respected(PriorInit()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. 
+ @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end - @testset "ParamsInit" begin end + @testset "UniformInit" begin + test_generating_new_values(UniformInit()) + test_replacing_values(UniformInit()) + test_rng_respected(UniformInit()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end - @testset "rng is respected (at least with PriorInit" begin end + @testset "ParamsInit" begin + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(fill(x, 2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + # In this case, we 
expect `ParamsInit` to use the value of x, and + # generate a new value for y. + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + end + end end end From 001a05aad080b18cc2d57187b3e0983a09f8b984 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:30:07 +0100 Subject: [PATCH 09/16] inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). --- src/contexts/init.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6ff276d21..3b7007f51 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -38,6 +38,8 @@ to unconstrained space, and then sampling a value uniformly between `lower` and If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's default initialisation strategy. +Requires that `lower <= upper`. 
+ # References [Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) @@ -55,7 +57,7 @@ end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) b = Bijectors.bijector(dist) sz = Bijectors.output_size(b, size(dist)) - y = rand(rng, Uniform(u.lower, u.upper), sz) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) b_inv = Bijectors.inverse(b) x = b_inv(y) # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 From b55c1e17f97ae518d1d149122e1fb1055557183f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:46:43 +0100 Subject: [PATCH 10/16] Document --- docs/src/api.md | 21 +++++++++++++++++++++ src/contexts/init.jl | 20 ++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e918a095c..3d5c681cf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -470,6 +470,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +PriorInit +UniformInit +ParamsInit +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. 
+ +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 3b7007f51..2b87b533b 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -32,11 +32,11 @@ init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand UniformInit(lower, upper) Obtain new values by first transforming the distribution of the random variable -to unconstrained space, and then sampling a value uniformly between `lower` and -`upper`. +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. -If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's -default initialisation strategy. +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. Requires that `lower <= upper`. @@ -91,17 +91,17 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) - # TODO(penelopeysm): We should do a check to make sure that all of the - # parameters in `p.params` were actually used, and either warn or error if - # they aren't. This is non-trivial (we need to use something like - # varname_leaves), so I'm going to defer it to a later PR. + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing init(rng, vn, dist, p.default) else - # TODO(penelopeysm): We could also check that the type of x matches - # the dist? 
+ # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? x end else From aa944b6dfeed2aef33d0a17b881d8565180e2829 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:40:13 +0100 Subject: [PATCH 11/16] Replace `evaluate_and_sample!!` -> `init!!` --- docs/src/api.md | 5 --- src/extract_priors.jl | 2 +- src/model.jl | 47 ++++---------------- src/sampler.jl | 3 +- src/simple_varinfo.jl | 47 ++++++++++++-------- src/test_utils/contexts.jl | 72 ++++++++++++++++++++----------- src/test_utils/model_interface.jl | 4 +- src/varinfo.jl | 68 ++++++++++++++++------------- test/compiler.jl | 13 +++--- test/contexts.jl | 19 ++++---- test/model.jl | 25 +---------- test/sampler.jl | 4 +- test/simple_varinfo.jl | 8 ++-- test/varinfo.jl | 53 +++++++++++------------ test/varnamedvector.jl | 4 +- 15 files changed, 173 insertions(+), 201 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 3d5c681cf..07843b34b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -456,11 +456,6 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. 
diff --git a/src/extract_priors.jl b/src/extract_priors.jl index bd6bdb2f2..557ed394a 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -116,7 +116,7 @@ function extract_priors(rng::Random.AbstractRNG, model::Model) # workaround for the fact that `order` is still hardcoded in VarInfo, and hence you # can't push new variables without knowing the num_produce. Remove this when possible. varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(), NumProduceAccumulator())) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 6be4eb383..7d7a7921d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -850,7 +850,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -863,46 +863,19 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) return Threads.nthreads() > 1 end -""" - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) - -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. - -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). - -Returns a tuple of the model's return value, plus the updated `varinfo` object. 
-""" -function evaluate_and_sample!!( - rng::Random.AbstractRNG, - model::Model, - varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), -) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) -end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() -) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) -end - """ init!!( - [rng::Random.AbstractRNG,] + [rng::Random.AbstractRNG, ] model::Model, varinfo::AbstractVarInfo, [init_strategy::AbstractInitStrategy=PriorInit()] ) -Evaluate the `model` and replace the values of the model's random variables in -the given `varinfo` with new values using a specified initialisation strategy. -If the values in `varinfo` are not already present, they will be added using -that same strategy. - -If `init_strategy` is not provided, defaults to PriorInit(). +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added +using that same initialisation strategy. If `init_strategy` is not provided, +it defaults to `PriorInit()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ @@ -1049,11 +1022,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. 
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end diff --git a/src/sampler.jl b/src/sampler.jl index 673b5128f..184e4b70a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,7 +58,8 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) + strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() + DynamicPPL.init!!(rng, model, vi, strategy) return vi, nothing end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0dc2415ca..5e4371fba 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. 
vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -226,24 +226,25 @@ end # Constructor from `Model`. 
function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. 
@@ -259,12 +260,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -474,7 +475,6 @@ function assume( return value, vi end -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end @@ -484,6 +484,15 @@ end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end +function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface; however, + # this is only a valid operation if it would be a no-op. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end +end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..4a019441b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod node_trait = DynamicPPL.NodeTrait(context) # Throw error immediately if it it's missing a `NodeTrait` implementation. node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) + error("Invalid NodeTrait: $node_trait") - # To see change, let's make sure we're using a different leaf context than the current. 
- leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) else - DefaultContext() + test_parent_context(context, model) end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf + + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. Thus we only test evaluation (i.e., assuming that the + # varinfo already contains all necessary variables). + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + new_model = contextualize(model, context) + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + @testset "{set,}{leaf,child}context" begin + # Ensure we're using a different leaf context than the current. 
+ leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. 
- varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index b3380e7f9..11afca2a3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -106,10 +106,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -122,12 +126,12 @@ the given `rng`, `sampler`. instead. 
""" function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -184,7 +188,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -192,15 +196,15 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -263,7 +267,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -271,19 +275,19 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -291,23 +295,25 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, deepcopy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=PriorInit() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -315,7 +321,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. 
""" function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -327,12 +333,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, deepcopy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..874b71204 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,8 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +595,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. 
@model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +617,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 5768757bb..d7f2113c5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -166,29 +166,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = 
FixedContext( + (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) + ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. - varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) + context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) + new_model = contextualize(model, context) + # Initialize a new varinfo with the prefixed model + _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) diff --git a/test/model.jl b/test/model.jl index daa3cc743..7f4313ee7 100644 --- a/test/model.jl +++ b/test/model.jl @@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() logjoint(model, chain) end - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +314,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented.
- vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -591,10 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] + chain = [VarInfo(m_lin_reg) _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/sampler.jl b/test/sampler.jl index fe9fd331a..d04f39ac9 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -69,8 +69,8 @@ end # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() + DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = UniformInit() + @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == PriorInit() for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) # model with one variable: initialization p = 0.2 diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 2f080e66b..f57a955d2 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -158,7 +158,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -226,9 +226,9 @@ # Initialize. 
svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -273,7 +273,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/varinfo.jl b/test/varinfo.jl index 7f4c2daad..caf3bdb55 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -43,7 +43,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -486,17 +486,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using UniformInit does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using UniformInit specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. 
However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to UniformInit tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, UniformInit()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -569,8 +570,8 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -578,8 +579,8 @@ end ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -587,8 +588,8 @@ end ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. 
+ vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -596,24 +597,24 @@ end ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + # Resample in unconstrained space. + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @@ -994,10 +995,9 @@ end end model1 = demo(1) varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. 
model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -1013,10 +1013,9 @@ end end model1 = demo_dot(1) varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..af24be86f 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. 
From c6df871bc5fdc44f91d14457a1fa2a78ca943bde Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:51:13 +0100 Subject: [PATCH 12/16] Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends --- ext/DynamicPPLMCMCChainsExt.jl | 38 +++++--- src/model.jl | 11 ++- src/varinfo.jl | 143 ---------------------------- test/ext/DynamicPPLMCMCChainsExt.jl | 7 +- test/model.jl | 2 +- test/test_util.jl | 4 +- test/varinfo.jl | 60 +----------- 7 files changed, 49 insertions(+), 216 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..cd86cfb5e 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -28,7 +28,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Any}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -114,9 +123,15 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in 
`values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. + retval, _ = DynamicPPL.init!!( + model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()) + ) + retval end end diff --git a/src/model.jl b/src/model.jl index 7d7a7921d..af394987f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1200,8 +1200,15 @@ function predict( varinfo = DynamicPPL.VarInfo(model) return map(chain) do params_varinfo vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) + # TODO(penelopeysm): Requires two model evaluations, one to extract the + # parameters and one to set them. The reason why we need values_as_in_model + # is because `params_varinfo` may well have some weird combination of + # linked/unlinked, whereas `varinfo` is always unlinked since it is + # freshly constructed. + # This is quite inefficient. 
It would of course be alright if + # ValuesAsInModelAccumulator was a default acc. + values_nt = values_as_in_model(model, false, params_varinfo) + _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit())) return vi end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 11afca2a3..888c51959 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1506,42 +1506,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. 
- val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -2045,113 +2009,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke return indices end -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. 
- -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! 
- unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/model.jl b/test/model.jl index 7f4313ee7..18e5e633f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -573,7 +573,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [VarInfo(m_lin_reg) _ in 1:10000] + chain = [VarInfo(m_lin_reg) for _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git 
a/test/test_util.jl b/test/test_util.jl index d5335249d..b7c46ff34 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/varinfo.jl b/test/varinfo.jl index caf3bdb55..ddbc4bfe8 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -278,7 +278,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -329,8 +329,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -355,57 +355,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. 
- continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -419,9 +368,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! 
on chain" begin From 949e9baf6b056a43eeba1d227e5d49a8e60c46d6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 12:31:37 +0100 Subject: [PATCH 13/16] Use `init!!` for initialisation --- docs/src/api.md | 7 ++- src/sampler.jl | 148 +++++++++--------------------------------------- 2 files changed, 34 insertions(+), 121 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 07843b34b..893885405 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -456,6 +456,11 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: + +```@docs +DynamicPPL.evaluate_and_sample!! +``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. @@ -509,7 +514,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. diff --git a/src/sampler.jl b/src/sampler.jl index 184e4b70a..d43b840a5 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -68,6 +68,8 @@ end Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. 
- `model::Model`: Model for which we want to create a varinfo object. @@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler) + # Note that variable values are unconditionally initialized later, so no + # point putting them in now. + return typed_varinfo(VarInfo()) end function AbstractMCMC.sample( @@ -96,24 +99,32 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden. +""" +init_strategy(::Sampler) = PriorInit() + function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. + # Generate the default varinfo (usually this just makes an empty VarInfo + # with NamedTuple of Metadata). vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. 
Note that, if `ParamsInit` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -131,110 +142,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. 
using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. - new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. 
- if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) From 17ffda1b9628ef1adba022f7bfe12dcb2ae4e66f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:10:00 +0100 Subject: [PATCH 14/16] Paper over the `Sampling->Init` context stack (pending removal of SamplingContext) --- src/context_implementations.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b11a723a5..a63a36c04 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -28,6 +28,11 @@ end function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi) + return error( + "Encountered SamplingContext->InitContext. This method will be removed in the next PR.", + ) +end function tilde_assume(::DefaultContext, sampler, right, vn, vi) # same as above but no rng return assume(Random.default_rng(), sampler, right, vn, vi) From d55d3787eebc8aad28909a0ab69fbf373598d37a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:25:00 +0100 Subject: [PATCH 15/16] Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway --- ext/DynamicPPLJETExt.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..89a36ffaf 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -21,22 +21,17 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - # First we try with the typed varinfo. 
- varinfo = DynamicPPL.typed_varinfo(sampling_model) + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. + # Let's make sure that evaluation doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + model, varinfo; only_ddpl ) if !issuccess # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" + @debug "Evaluation with typed varinfo failed with the following issues:" @debug result end @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end From 3ab2061bf0279b8b5003f310bb05ccf687fd3a4e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 17:57:23 +0100 Subject: [PATCH 16/16] [no ci] Remove `SamplingContext` for good --- docs/src/api.md | 11 +--- ext/DynamicPPLEnzymeCoreExt.jl | 2 - src/DynamicPPL.jl | 3 - src/context_implementations.jl | 112 ++------------------------------- src/contexts.jl | 69 +------------------- src/debug_utils.jl | 5 +- src/sampler.jl | 45 ------------- src/simple_varinfo.jl | 17 ----- src/utils.jl | 44 ------------- test/ad.jl | 40 ------------ test/contexts.jl | 21 +------ test/ext/DynamicPPLJETExt.jl | 11 +--- test/sampler.jl | 106 +++++++++++++++---------------- test/threadsafe.jl | 6 +- 14 files changed, 68 insertions(+), 424 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 893885405..ab618715f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -466,7 +466,6 @@ The behaviour of a model execution can be changed with evaluation contexts, whic Contexts are subtypes of `AbstractPPL.AbstractContext`. 
```@docs -SamplingContext DefaultContext PrefixContext ConditionContext @@ -495,15 +494,7 @@ DynamicPPL.init ### Samplers -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. - -```@docs -SampleFromPrior -SampleFromUniform -``` - -Additionally, a generic sampler for inference is implemented. +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index ceb3f4981..f2d24ad92 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) = diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bb6af996e..486788cbf 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -97,13 +97,10 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a63a36c04..ba746896b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,43 +1,14 @@ # assume -""" - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. 
- -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` -""" -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - function tilde_assume(context::AbstractContext, args...) return tilde_assume(childcontext(context), args...) end function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi) - return error( - "Encountered SamplingContext->InitContext. This method will be removed in the next PR.", - ) -end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, logjac, vn, right) + return x, vi end - function tilde_assume(context::PrefixContext, right, vn, vi) # Note that we can't use something like this here: # new_vn = prefix(context, vn) @@ -51,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi) new_vn, new_context = prefix_and_strip_contexts(context, vn) return tilde_assume(new_context, right, new_vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi -) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) -end """ tilde_assume!!(context, right, vn, vi) @@ -76,17 +41,6 @@ function tilde_assume!!(context, 
right, vn, vi) end # observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe!!(context.context, right, left, vi)`. -""" -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - function tilde_observe!!(context::AbstractContext, right, left, vn, vi) return tilde_observe!!(childcontext(context), right, left, vn, vi) end @@ -119,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, -) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! 
- vi = BangBang.setindex!!(vi, f(r), vn) - setorder!(vi, vn, get_num_produce(vi)) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, -logjac, vn, dist) - return r, vi -end diff --git a/src/contexts.jl b/src/contexts.jl index cd9876768..e50ba0df3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. 
- -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. 
-""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 4343ce8ac..af5e07d37 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -438,9 +438,10 @@ function check_model_and_trace( kwargs..., ) # Execute the model with the debug context. - debug_context = DebugContext( - SamplingContext(rng, model.context); error_on_failure=error_on_failure, kwargs... + new_context = DynamicPPL.setleafcontext( + model.context, DynamicPPL.InitContext(rng, DynamicPPL.PriorInit()) ) + debug_context = DebugContext(new_context; error_on_failure=error_on_failure, kwargs...) debug_model = DynamicPPL.contextualize(model, debug_context) # Perform checks before evaluating the model. diff --git a/src/sampler.jl b/src/sampler.jl index d43b840a5..4f4b0ed45 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? 
inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() - DynamicPPL.init!!(rng, model, vi, strategy) - return vi, nothing -end - """ default_varinfo(rng, model, sampler) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5e4371fba..86b5c80bb 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -457,23 +457,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) end # Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, -logjac, vn, dist) - return value, vi -end function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) diff --git a/src/utils.jl b/src/utils.jl index af2891a2b..8a770d9ef 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return 
Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/test/ad.jl b/test/ad.jl index 48dffeadb..aed5a7abf 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,46 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction(sampling_model, vi; adtype=AutoReverseDiff(; compile=true)) - @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. 
@testset "Array argument types" begin diff --git a/test/contexts.jl b/test/contexts.jl index d7f2113c5..b643a864a 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -50,7 +50,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( @@ -151,11 +150,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -204,22 +203,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin 
@testset "NamedTuple" begin diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..38cd62554 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,19 +62,14 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling + # Check that the inferred varinfo is indeed suitable for evaluation and initialisation f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed @@ -85,10 +80,6 @@ model, typed_vi ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi - ) - JET.test_call(f_sample, argtypes_sample) end end end diff --git a/test/sampler.jl b/test/sampler.jl index d04f39ac9..0438362a6 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,57 +1,57 @@ @testset "sampler.jl" begin - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. 
- @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end + # @testset "SampleFromPrior and SampleUniform" begin + # @model function gdemo(x, y) + # s ~ InverseGamma(2, 3) + # m ~ Normal(2.0, sqrt(s)) + # x ~ Normal(m, sqrt(s)) + # return y ~ Normal(m, sqrt(s)) + # end + # + # model = gdemo(1.0, 2.0) + # N = 1_000 + # + # chains = sample(model, SampleFromPrior(), N; progress=false) + # @test chains isa Vector{<:VarInfo} + # @test length(chains) == N + # + # # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. + # @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 + # + # # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. 
+ # @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 + # + # chains = sample(model, SampleFromUniform(), N; progress=false) + # @test chains isa Vector{<:VarInfo} + # @test length(chains) == N + # + # # `m` is Gaussian, i.e. no transformation is used, so it + # # should have a mean equal to its prior, i.e. 2. + # @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + # + # # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. + # @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 + # end + + # @testset "init" begin + # @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # N = 1000 + # chain_init = sample(model, SampleFromUniform(), N; progress=false) + # + # for vn in keys(first(chain_init)) + # if AbstractPPL.subsumes(@varname(s), vn) + # # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. + # dist = InverseGamma(2, 3) + # b = DynamicPPL.link_transform(dist) + # @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 + # elseif AbstractPPL.subsumes(@varname(m), vn) + # # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. + # @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 + # else + # error("Unknown variable name: $vn") + # end + # end + # end + # end @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 24a738a78..85d86047a 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. 
- sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -104,8 +103,7 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo