From 3b87baae450162c17f49e5920120ac471f8907a9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:20:10 +0100 Subject: [PATCH 01/13] Implement InitContext --- src/DynamicPPL.jl | 7 ++ src/contexts/init.jl | 180 +++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 33 ++++++++ test/contexts.jl | 12 ++- 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/contexts/init.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..d117245e1 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -108,6 +108,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + PriorInit, + UniformInit, + ParamsInit, # Pseudo distributions NamedDist, NoDist, @@ -174,6 +180,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..580b1a666 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,180 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Values must be unlinked" + The values returned by `init` are always in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::UniformInit)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + PriorInit() + +Obtain new values by sampling from the prior distribution. +""" +struct PriorInit <: AbstractInitStrategy end +init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) + +""" + UniformInit() + UniformInit(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, and then sampling a value uniformly between `lower` and +`upper`. + +If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's +default initialisation strategy. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + UniformInit() = UniformInit(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = rand(rng, Uniform(u.lower, u.upper), sz) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) + ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + +Obtain new values by extracting them from the given dictionary or NamedTuple. +The parameter `default` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. The default +for `default` is `PriorInit()`. + +!!! note + These values must be provided in the space of the untransformed distribution. +""" +struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy + params::P + default::S + function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + return new{typeof(params),typeof(default)}(params, default) + end + ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) + function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + return ParamsInit(to_varname_dict(params), default) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) + # TODO(penelopeysm): We should do a check to make sure that all of the + # parameters in `p.params` were actually used, and either warn or error if + # they aren't. This is non-trivial (we need to use something like + # varname_leaves), so I'm going to defer it to a later PR. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + init(rng, vn, dist, p.default) + else + # TODO(penelopeysm): We could also check that the type of x matches + # the dist? + x + end + else + init(rng, vn, dist, p.default) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=PriorInit()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=PriorInit()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + to_linked_internal_transform(vi, vn, dist) + else + to_internal_transform(vi, vn, dist) + end + # TODO(penelopeysm): We would really like to do: + # y, logjac = with_logabsdet_jacobian(f, x) + # Unfortunately, `to_{linked_}internal_transform` returns a function that + # always converts x to a vector, i.e., if dist is univariate, f(x) will be + # a vector of length 1. It would be nice if we could unify these. + y = f(x) + logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..73129d97f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -854,6 +854,39 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=PriorInit()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to PriorInit(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. +""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=PriorInit(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..be976aad4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -431,4 +431,14 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + @testset "PriorInit" begin end + + @testset "UniformInit" begin end + + @testset "ParamsInit" begin end + + @testset "rng is respected (at least with PriorInit" begin end + end end From 75f69b9329b253850bed3e16929ff228bedebac7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:25:10 +0100 Subject: [PATCH 02/13] Fix loading order of modules; move `prefix(::Model)` to model.jl --- src/DynamicPPL.jl | 4 ++-- src/contexts.jl | 35 ----------------------------------- src/model.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d117245e1..067ef327a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -175,12 +175,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") -include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/model.jl b/src/model.jl index 73129d97f..0c51cab0e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) From 1988baae867aafe2a20aadc57537c9bb1a93c3e8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:21:53 +0100 Subject: [PATCH 03/13] Add tests for InitContext behaviour --- src/contexts/init.jl | 12 +-- test/contexts.jl | 183 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 179 insertions(+), 16 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 580b1a666..6ff276d21 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -147,17 +147,11 @@ function tilde_assume( # are linked. insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) f = if insert_transformed_value - to_linked_internal_transform(vi, vn, dist) + link_transform(dist) else - to_internal_transform(vi, vn, dist) + identity end - # TODO(penelopeysm): We would really like to do: - # y, logjac = with_logabsdet_jacobian(f, x) - # Unfortunately, `to_{linked_}internal_transform` returns a function that - # always converts x to a vector, i.e., if dist is univariate, f(x) will be - # a vector of length 1. It would be nice if we could unify these. - y = f(x) - logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + y, logjac = with_logabsdet_jacobian(f, x) # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo diff --git a/test/contexts.jl b/test/contexts.jl index be976aad4..5768757bb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - @testset "PriorInit" begin end + empty_varinfos = [ + VarInfo(), + DynamicPPL.typed_varinfo(VarInfo()), + VarInfo(DynamicPPL.VarNamedVector()), + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + SimpleVarInfo(), + SimpleVarInfo(Dict{VarName,Any}()), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + for empty_vi in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + for empty_vi in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + for empty_vi in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end - @testset "UniformInit" begin end + @testset "PriorInit" begin + test_generating_new_values(PriorInit()) + test_replacing_values(PriorInit()) + test_rng_respected(PriorInit()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end - @testset "ParamsInit" begin end + @testset "UniformInit" begin + test_generating_new_values(UniformInit()) + test_replacing_values(UniformInit()) + test_rng_respected(UniformInit()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end - @testset "rng is respected (at least with PriorInit" begin end + @testset "ParamsInit" begin + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + # In this case, we expect `ParamsInit` to use the value of x, and + # generate a new value for y. + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + end + end end end From 4d4ffd912e26b3de5abc66c4a3d6fe68f8c8e4ce Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:30:07 +0100 Subject: [PATCH 04/13] inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). --- src/contexts/init.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6ff276d21..3b7007f51 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -38,6 +38,8 @@ to unconstrained space, and then sampling a value uniformly between `lower` and If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's default initialisation strategy. +Requires that `lower <= upper`. + # References [Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) @@ -55,7 +57,7 @@ end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) b = Bijectors.bijector(dist) sz = Bijectors.output_size(b, size(dist)) - y = rand(rng, Uniform(u.lower, u.upper), sz) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) b_inv = Bijectors.inverse(b) x = b_inv(y) # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 From d24f96645d0a72dcdaeba5ec9b179b42287c0acd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:46:43 +0100 Subject: [PATCH 05/13] Document --- docs/src/api.md | 21 +++++++++++++++++++++ src/contexts/init.jl | 20 ++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..11e999451 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -463,6 +463,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +PriorInit +UniformInit +ParamsInit +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 3b7007f51..2b87b533b 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -32,11 +32,11 @@ init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand UniformInit(lower, upper) Obtain new values by first transforming the distribution of the random variable -to unconstrained space, and then sampling a value uniformly between `lower` and -`upper`. +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. -If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's -default initialisation strategy. +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. Requires that `lower <= upper`. @@ -91,17 +91,17 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) - # TODO(penelopeysm): We should do a check to make sure that all of the - # parameters in `p.params` were actually used, and either warn or error if - # they aren't. This is non-trivial (we need to use something like - # varname_leaves), so I'm going to defer it to a later PR. + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing init(rng, vn, dist, p.default) else - # TODO(penelopeysm): We could also check that the type of x matches - # the dist? + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? x end else From 42d74b6dfb54e49965840e1df5f645f2cfb08e34 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:36:33 +0100 Subject: [PATCH 06/13] Add a test to check that `init!!` doesn't change linking --- src/contexts/init.jl | 2 +- test/contexts.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 2b87b533b..baa1bf088 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -165,7 +165,7 @@ function tilde_assume( # necessary. insert_transformed_value && settrans!!(vi, true, vn) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. return x, vi diff --git a/test/contexts.jl b/test/contexts.jl index 5768757bb..e2a882f48 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end @testset "PriorInit" begin test_generating_new_values(PriorInit()) test_replacing_values(PriorInit()) test_rng_respected(PriorInit()) + test_link_status_respected(PriorInit()) @testset "check that values are within support" begin # Not many other sensible checks we can do for priors. @@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() test_generating_new_values(UniformInit()) test_replacing_values(UniformInit()) test_rng_respected(UniformInit()) + test_link_status_respected(UniformInit()) @testset "check that bounds are respected" begin @testset "unconstrained" begin @@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "ParamsInit" begin + test_link_status_respected(ParamsInit((; a=1.0))) + test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0))) + @testset "given full set of parameters" begin # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) my_x, my_y = 1.0, [2.0, 3.0] From ab3e8da9998108416f4d0dd4b0c4065a8162bfde Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:58:18 +0100 Subject: [PATCH 07/13] Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. --- src/varnamedvector.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new From be8a1b348117b023f67262ad23a6723562e9703d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 12 Aug 2025 18:07:06 +0100 Subject: [PATCH 08/13] Add some line breaks Co-authored-by: Markus Hauru --- test/contexts.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/contexts.jl b/test/contexts.jl index e2a882f48..d900ccf61 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -449,6 +449,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() 1.0 ~ Normal() return nothing end + function test_generating_new_values(strategy::AbstractInitStrategy) @testset "generating new values: $(typeof(strategy))" begin # Check that init!! can generate values that weren't there @@ -469,6 +470,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_replacing_values(strategy::AbstractInitStrategy) @testset "replacing old values: $(typeof(strategy))" begin # Check that init!! can overwrite values that were already there. @@ -488,6 +490,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() @@ -508,6 +511,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_link_status_respected(strategy::AbstractInitStrategy) @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin @model logn() = a ~ LogNormal() From c40a1933cd5e2eaff033188d73a9891f8593b6a7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 12 Aug 2025 18:24:33 +0100 Subject: [PATCH 09/13] Add the option of no fallback for ParamsInit --- src/contexts/init.jl | 29 ++++++++++---- test/contexts.jl | 90 ++++++++++++++++++++++++++++++-------------- 2 files changed, 83 insertions(+), 36 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index baa1bf088..0ca14c85f 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -68,25 +68,35 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform end """ - ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) - ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + ParamsInit( + params::Union{AbstractDict{<:VarName},NamedTuple}, + default::Union{AbstractInitStrategy,Nothing}=PriorInit() + ) Obtain new values by extracting them from the given dictionary or NamedTuple. + The parameter `default` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. The default -for `default` is `PriorInit()`. +cannot be found in `params`, or they are specified as `missing`. `default` +can either be an initialisation strategy itself, in which case it will be +used to obtain new values, or it can be `nothing`, in which case an error +will be thrown. The default for `default` is `PriorInit()`. !!! note - These values must be provided in the space of the untransformed distribution. + The values in `params` must be provided in the space of the untransformed +distribution. """ -struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy +struct ParamsInit{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P default::S - function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + function ParamsInit( + params::AbstractDict{<:VarName}, default::Union{AbstractInitStrategy,Nothing} + ) return new{typeof(params),typeof(default)}(params, default) end ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) - function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + function ParamsInit( + params::NamedTuple, default::Union{AbstractInitStrategy,Nothing}=PriorInit() + ) return ParamsInit(to_varname_dict(params), default) end end @@ -98,6 +108,8 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::Param return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing + p.default === nothing && + error("A `missing` value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.default) else # TODO(penelopeysm): Since x is user-supplied, maybe we could also @@ -105,6 +117,7 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::Param x end else + p.default === nothing && error("No value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.default) end end diff --git a/test/contexts.jl b/test/contexts.jl index d900ccf61..b0b5fb8db 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -435,12 +435,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "InitContext" begin empty_varinfos = [ - VarInfo(), - DynamicPPL.typed_varinfo(VarInfo()), - VarInfo(DynamicPPL.VarNamedVector()), - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), - SimpleVarInfo(), - SimpleVarInfo(Dict{VarName,Any}()), + ("untyped+metadata", VarInfo()), + ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), + ( + "typed+VNV", + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + ), + ("SVI+NamedTuple", SimpleVarInfo()), + ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), ] @model function test_init_model() @@ -455,7 +458,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Check that init!! can generate values that weren't there # previously. model = test_init_model() - for empty_vi in empty_varinfos + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos this_vi = deepcopy(empty_vi) _, vi = DynamicPPL.init!!(model, this_vi, strategy) @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) @@ -475,7 +478,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "replacing old values: $(typeof(strategy))" begin # Check that init!! can overwrite values that were already there. model = test_init_model() - for empty_vi in empty_varinfos + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos # start by generating some rubbish values vi = deepcopy(empty_vi) old_x, old_y = 100000.00, [300000.00, 500000.00] @@ -494,7 +497,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() - for empty_vi in empty_varinfos + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos _, vi1 = DynamicPPL.init!!( Xoshiro(468), model, deepcopy(empty_vi), strategy ) @@ -613,29 +616,60 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "given only partial parameters" begin - # In this case, we expect `ParamsInit` to use the value of x, and - # generate a new value for y. my_x = 1.0 params_nt = (; x=my_x) params_dict = Dict(@varname(x) => my_x) model = test_init_model() - for empty_vi in empty_varinfos - _, vi = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) - ) - @test vi[@varname(x)] == my_x - nt_y = vi[@varname(y)] - @test nt_y isa AbstractVector{<:Real} - @test length(nt_y) == 2 - _, vi = DynamicPPL.init!!( - Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) - ) - @test vi[@varname(x)] == my_x - dict_y = vi[@varname(y)] - @test dict_y isa AbstractVector{<:Real} - @test length(dict_y) == 2 - # the values should be different since we used different seeds - @test dict_y != nt_y + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + @testset "with PriorInit fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + ParamsInit(params_nt, PriorInit()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + ParamsInit(params_dict, PriorInit()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. + params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict( + @varname(x) => my_x, @varname(y) => missing + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + ParamsInit(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + ParamsInit(params_dict_missing, nothing), + ) + end end end end From bcfdd931331e12311b46276aafa728dd68326e6f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 12 Aug 2025 18:32:59 +0100 Subject: [PATCH 10/13] Improve docstrings --- src/contexts/init.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 0ca14c85f..cb4f24759 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -3,6 +3,9 @@ Abstract type representing the possible ways of initialising new values for the random variables in a model (e.g., when creating a new VarInfo). + +Any subtype of `AbstractInitStrategy` must implement the +[`DynamicPPL.init`](@ref) method. """ abstract type AbstractInitStrategy end @@ -11,8 +14,8 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Values must be unlinked" - The values returned by `init` are always in the untransformed space, i.e., +!!! warning "Return values must be unlinked" + The values returned by `init` must always be in the untransformed space, i.e., they must be within the support of the original distribution. That means that, for example, `init(rng, dist, u::UniformInit)` will in general return values that are outside the range [u.lower, u.upper]. From 9626ae96d7f6dfaa4c104726c9383aaaa775bfb5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 01:24:47 +0100 Subject: [PATCH 11/13] typo --- test/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index b0b5fb8db..93dc52d88 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -598,7 +598,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x, y=my_y) params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) model = test_init_model() - for empty_vi in empty_varinfos + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos _, vi = DynamicPPL.init!!( model, deepcopy(empty_vi), ParamsInit(params_nt) ) From 1e4ce9f75ab54a16b34d0f3212fe320b3ebeb6f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 12:04:53 +0100 Subject: [PATCH 12/13] `p.default` -> `p.fallback` --- src/contexts/init.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index cb4f24759..75807ae32 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -73,16 +73,16 @@ end """ ParamsInit( params::Union{AbstractDict{<:VarName},NamedTuple}, - default::Union{AbstractInitStrategy,Nothing}=PriorInit() + fallback::Union{AbstractInitStrategy,Nothing}=PriorInit() ) Obtain new values by extracting them from the given dictionary or NamedTuple. -The parameter `default` specifies how new values are to be obtained if they -cannot be found in `params`, or they are specified as `missing`. `default` +The parameter `fallback` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. `fallback` can either be an initialisation strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `default` is `PriorInit()`. +will be thrown. The default for `fallback` is `PriorInit()`. !!! note The values in `params` must be provided in the space of the untransformed @@ -90,17 +90,17 @@ distribution. """ struct ParamsInit{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P - default::S + fallback::S function ParamsInit( - params::AbstractDict{<:VarName}, default::Union{AbstractInitStrategy,Nothing} + params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} ) - return new{typeof(params),typeof(default)}(params, default) + return new{typeof(params),typeof(fallback)}(params, fallback) end ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) function ParamsInit( - params::NamedTuple, default::Union{AbstractInitStrategy,Nothing}=PriorInit() + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=PriorInit() ) - return ParamsInit(to_varname_dict(params), default) + return ParamsInit(to_varname_dict(params), fallback) end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) @@ -111,17 +111,17 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::Param return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing - p.default === nothing && + p.fallback === nothing && error("A `missing` value was provided for the variable `$(vn)`.") - init(rng, vn, dist, p.default) + init(rng, vn, dist, p.fallback) else # TODO(penelopeysm): Since x is user-supplied, maybe we could also # check here that the type / size of x matches the dist? x end else - p.default === nothing && error("No value was provided for the variable `$(vn)`.") - init(rng, vn, dist, p.default) + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) end end From 6705d7b31da0718aedbc8cf0a6063f48515076c4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 17:05:43 +0100 Subject: [PATCH 13/13] Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}` --- docs/src/api.md | 6 ++--- src/DynamicPPL.jl | 6 ++--- src/contexts/init.jl | 52 ++++++++++++++++++++++++-------------------- src/model.jl | 10 +++++---- test/contexts.jl | 50 +++++++++++++++++++++--------------------- 5 files changed, 65 insertions(+), 59 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 11e999451..c6244b75f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -474,9 +474,9 @@ To accomplish this, an initialisation _strategy_ is required, which defines how There are three concrete strategies provided in DynamicPPL: ```@docs -PriorInit -UniformInit -ParamsInit +InitFromPrior +InitFromUniform +InitFromParams ``` If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 067ef327a..859c7d49d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -111,9 +111,9 @@ export AbstractVarInfo, # Initialisation InitContext, AbstractInitStrategy, - PriorInit, - UniformInit, - ParamsInit, + InitFromPrior, + InitFromUniform, + InitFromParams, # Pseudo distributions NamedDist, NoDist, diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 75807ae32..636847117 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -17,22 +17,24 @@ Generate a new value for a random variable with the given distribution. !!! warning "Return values must be unlinked" The values returned by `init` must always be in the untransformed space, i.e., they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::UniformInit)` will in general return values that + for example, `init(rng, dist, u::InitFromUniform)` will in general return values that are outside the range [u.lower, u.upper]. """ function init end """ - PriorInit() + InitFromPrior() Obtain new values by sampling from the prior distribution. """ -struct PriorInit <: AbstractInitStrategy end -init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) +struct InitFromPrior <: AbstractInitStrategy end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) + return rand(rng, dist) +end """ - UniformInit() - UniformInit(lower, upper) + InitFromUniform() + InitFromUniform(lower, upper) Obtain new values by first transforming the distribution of the random variable to unconstrained space, then sampling a value uniformly between `lower` and @@ -47,17 +49,17 @@ Requires that `lower <= upper`. [Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) """ -struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy +struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy lower::T upper::T - function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat} lower > upper && throw(ArgumentError("`lower` must be less than or equal to `upper`")) return new{T}(lower, upper) end - UniformInit() = UniformInit(-2.0, 2.0) + InitFromUniform() = InitFromUniform(-2.0, 2.0) end -function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) b = Bijectors.bijector(dist) sz = Bijectors.output_size(b, size(dist)) y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) @@ -71,9 +73,9 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::Uniform end """ - ParamsInit( + InitFromParams( params::Union{AbstractDict{<:VarName},NamedTuple}, - fallback::Union{AbstractInitStrategy,Nothing}=PriorInit() + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) Obtain new values by extracting them from the given dictionary or NamedTuple. @@ -82,28 +84,30 @@ The parameter `fallback` specifies how new values are to be obtained if they cannot be found in `params`, or they are specified as `missing`. `fallback` can either be an initialisation strategy itself, in which case it will be used to obtain new values, or it can be `nothing`, in which case an error -will be thrown. The default for `fallback` is `PriorInit()`. +will be thrown. The default for `fallback` is `InitFromPrior()`. !!! note The values in `params` must be provided in the space of the untransformed -distribution. + distribution. """ -struct ParamsInit{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy params::P fallback::S - function ParamsInit( + function InitFromParams( params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} ) return new{typeof(params),typeof(fallback)}(params, fallback) end - ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) - function ParamsInit( - params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=PriorInit() + function InitFromParams(params::AbstractDict{<:VarName}) + return InitFromParams(params, InitFromPrior()) + end + function InitFromParams( + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() ) - return ParamsInit(to_varname_dict(params), fallback) + return InitFromParams(to_varname_dict(params), fallback) end end -function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) # TODO(penelopeysm): It would be nice to do a check to make sure that all # of the parameters in `p.params` were actually used, and either warn or # error if they aren't. This is actually quite non-trivial though because @@ -128,7 +132,7 @@ end """ InitContext( [rng::Random.AbstractRNG=Random.default_rng()], - [strategy::AbstractInitStrategy=PriorInit()], + [strategy::AbstractInitStrategy=InitFromPrior()], ) A leaf context that indicates that new values for random variables are @@ -140,11 +144,11 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon rng::R strategy::S function InitContext( - rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() ) return new{typeof(rng),typeof(strategy)}(rng, strategy) end - function InitContext(strategy::AbstractInitStrategy=PriorInit()) + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) return InitContext(Random.default_rng(), strategy) end end diff --git a/src/model.jl b/src/model.jl index 0c51cab0e..e7a1a864f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -894,7 +894,7 @@ end [rng::Random.AbstractRNG,] model::Model, varinfo::AbstractVarInfo, - [init_strategy::AbstractInitStrategy=PriorInit()] + [init_strategy::AbstractInitStrategy=InitFromPrior()] ) Evaluate the `model` and replace the values of the model's random variables in @@ -902,7 +902,7 @@ the given `varinfo` with new values using a specified initialisation strategy. If the values in `varinfo` are not already present, they will be added using that same strategy. -If `init_strategy` is not provided, defaults to PriorInit(). +If `init_strategy` is not provided, defaults to InitFromPrior(). Returns a tuple of the model's return value, plus the updated `varinfo` object. """ @@ -910,14 +910,16 @@ function init!!( rng::Random.AbstractRNG, model::Model, varinfo::AbstractVarInfo, - init_strategy::AbstractInitStrategy=PriorInit(), + init_strategy::AbstractInitStrategy=InitFromPrior(), ) new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) new_model = contextualize(model, new_context) return evaluate!!(new_model, varinfo) end function init!!( - model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) return init!!(Random.default_rng(), model, varinfo, init_strategy) end diff --git a/test/contexts.jl b/test/contexts.jl index 93dc52d88..365865e7e 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -537,28 +537,28 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PriorInit" begin - test_generating_new_values(PriorInit()) - test_replacing_values(PriorInit()) - test_rng_respected(PriorInit()) - test_link_status_respected(PriorInit()) + @testset "InitFromPrior" begin + test_generating_new_values(InitFromPrior()) + test_replacing_values(InitFromPrior()) + test_rng_respected(InitFromPrior()) + test_link_status_respected(InitFromPrior()) @testset "check that values are within support" begin # Not many other sensible checks we can do for priors. @model just_unif() = x ~ Uniform(0.0, 1e-7) for _ in 1:100 - _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) @test vi[@varname(x)] isa Real @test 0.0 <= vi[@varname(x)] <= 1e-7 end end end - @testset "UniformInit" begin - test_generating_new_values(UniformInit()) - test_replacing_values(UniformInit()) - test_rng_respected(UniformInit()) - test_link_status_respected(UniformInit()) + @testset "InitFromUniform" begin + test_generating_new_values(InitFromUniform()) + test_replacing_values(InitFromUniform()) + test_rng_respected(InitFromUniform()) + test_link_status_respected(InitFromUniform()) @testset "check that bounds are respected" begin @testset "unconstrained" begin @@ -566,7 +566,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @model just_norm() = x ~ Normal() for _ in 1:100 _, vi = DynamicPPL.init!!( - just_norm(), VarInfo(), UniformInit(umin, umax) + just_norm(), VarInfo(), InitFromUniform(umin, umax) ) @test vi[@varname(x)] isa Real @test umin <= vi[@varname(x)] <= umax @@ -579,7 +579,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() tmin, tmax = inv_bijector(umin), inv_bijector(umax) for _ in 1:100 _, vi = DynamicPPL.init!!( - just_beta(), VarInfo(), UniformInit(umin, umax) + just_beta(), VarInfo(), InitFromUniform(umin, umax) ) @test vi[@varname(x)] isa Real @test tmin <= vi[@varname(x)] <= tmax @@ -588,9 +588,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "ParamsInit" begin - test_link_status_respected(ParamsInit((; a=1.0))) - test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0))) + @testset "InitFromParams" begin + test_link_status_respected(InitFromParams((; a=1.0))) + test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0))) @testset "given full set of parameters" begin # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) @@ -600,13 +600,13 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() model = test_init_model() @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), ParamsInit(params_nt) + model, deepcopy(empty_vi), InitFromParams(params_nt) ) @test vi[@varname(x)] == my_x @test vi[@varname(y)] == my_y logp_nt = getlogp(vi) _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), ParamsInit(params_dict) + model, deepcopy(empty_vi), InitFromParams(params_dict) ) @test vi[@varname(x)] == my_x @test vi[@varname(y)] == my_y @@ -621,12 +621,12 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_dict = Dict(@varname(x) => my_x) model = test_init_model() @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - @testset "with PriorInit fallback" begin + @testset "with InitFromPrior fallback" begin _, vi = DynamicPPL.init!!( Xoshiro(468), model, deepcopy(empty_vi), - ParamsInit(params_nt, PriorInit()), + InitFromParams(params_nt, InitFromPrior()), ) @test vi[@varname(x)] == my_x nt_y = vi[@varname(y)] @@ -636,7 +636,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() Xoshiro(469), model, deepcopy(empty_vi), - ParamsInit(params_dict, PriorInit()), + InitFromParams(params_dict, InitFromPrior()), ) @test vi[@varname(x)] == my_x dict_y = vi[@varname(y)] @@ -649,10 +649,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "with no fallback" begin # These just don't have an entry for `y`. @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), ParamsInit(params_nt, nothing) + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) ) @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), ParamsInit(params_dict, nothing) + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) ) # We also explicitly test the case where `y = missing`. params_nt_missing = (; x=my_x, y=missing) @@ -662,12 +662,12 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test_throws ErrorException DynamicPPL.init!!( model, deepcopy(empty_vi), - ParamsInit(params_nt_missing, nothing), + InitFromParams(params_nt_missing, nothing), ) @test_throws ErrorException DynamicPPL.init!!( model, deepcopy(empty_vi), - ParamsInit(params_dict_missing, nothing), + InitFromParams(params_dict_missing, nothing), ) end end