From 2a1ebff3bd08935d519d8a75e01c3533dd1518c9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:20:10 +0100 Subject: [PATCH 01/20] Implement InitContext --- src/DynamicPPL.jl | 7 ++ src/contexts/init.jl | 180 +++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 33 ++++++++ test/contexts.jl | 12 ++- 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 src/contexts/init.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f190c7605..ff0223457 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -109,6 +109,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + PriorInit, + UniformInit, + ParamsInit, # Pseudo distributions NamedDist, NoDist, @@ -175,6 +181,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..580b1a666 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,180 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Values must be unlinked" + The values returned by `init` are always in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::UniformInit)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + PriorInit() + +Obtain new values by sampling from the prior distribution. +""" +struct PriorInit <: AbstractInitStrategy end +init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand(rng, dist) + +""" + UniformInit() + UniformInit(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, and then sampling a value uniformly between `lower` and +`upper`. + +If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's +default initialisation strategy. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct UniformInit{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function UniformInit(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + UniformInit() = UniformInit(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = rand(rng, Uniform(u.lower, u.upper), sz) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy=PriorInit()) + ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + +Obtain new values by extracting them from the given dictionary or NamedTuple. +The parameter `default` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. The default +for `default` is `PriorInit()`. + +!!! note + These values must be provided in the space of the untransformed distribution. +""" +struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy + params::P + default::S + function ParamsInit(params::AbstractDict{<:VarName}, default::AbstractInitStrategy) + return new{typeof(params),typeof(default)}(params, default) + end + ParamsInit(params::AbstractDict{<:VarName}) = ParamsInit(params, PriorInit()) + function ParamsInit(params::NamedTuple, default::AbstractInitStrategy=PriorInit()) + return ParamsInit(to_varname_dict(params), default) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) + # TODO(penelopeysm): We should do a check to make sure that all of the + # parameters in `p.params` were actually used, and either warn or error if + # they aren't. This is non-trivial (we need to use something like + # varname_leaves), so I'm going to defer it to a later PR. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + init(rng, vn, dist, p.default) + else + # TODO(penelopeysm): We could also check that the type of x matches + # the dist? + x + end + else + init(rng, vn, dist, p.default) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=PriorInit()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=PriorInit() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=PriorInit()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + to_linked_internal_transform(vi, vn, dist) + else + to_internal_transform(vi, vn, dist) + end + # TODO(penelopeysm): We would really like to do: + # y, logjac = with_logabsdet_jacobian(f, x) + # Unfortunately, `to_{linked_}internal_transform` returns a function that + # always converts x to a vector, i.e., if dist is univariate, f(x) will be + # a vector of length 1. It would be nice if we could unify these. + y = f(x) + logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index ac9968cf2..da01f3a1a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -854,6 +854,39 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=PriorInit()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to PriorInit(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. +""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=PriorInit(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=PriorInit() +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..be976aad4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -431,4 +431,14 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + @testset "PriorInit" begin end + + @testset "UniformInit" begin end + + @testset "ParamsInit" begin end + + @testset "rng is respected (at least with PriorInit" begin end + end end From 06d0beb6d11a0c0aaf2363babe64cfe3bd63842e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 9 Jul 2025 23:25:10 +0100 Subject: [PATCH 02/20] Fix loading order of modules; move `prefix(::Model)` to model.jl --- src/DynamicPPL.jl | 4 ++-- src/contexts.jl | 35 ----------------------------------- src/model.jl | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ff0223457..6050ce344 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,12 +176,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") -include("contexts/init.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/model.jl b/src/model.jl index da01f3a1a..5a07129a0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) From 1f7017a1711cbf4764eaf890c389cb4ea87f8348 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:21:53 +0100 Subject: [PATCH 03/20] Add tests for InitContext behaviour --- src/contexts/init.jl | 12 +-- test/contexts.jl | 183 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 179 insertions(+), 16 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 580b1a666..6ff276d21 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -147,17 +147,11 @@ function tilde_assume( # are linked. insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) f = if insert_transformed_value - to_linked_internal_transform(vi, vn, dist) + link_transform(dist) else - to_internal_transform(vi, vn, dist) + identity end - # TODO(penelopeysm): We would really like to do: - # y, logjac = with_logabsdet_jacobian(f, x) - # Unfortunately, `to_{linked_}internal_transform` returns a function that - # always converts x to a vector, i.e., if dist is univariate, f(x) will be - # a vector of length 1. It would be nice if we could unify these. - y = f(x) - logjac = logabsdetjac(insert_transformed_value ? link_transform(dist) : identity, x) + y, logjac = with_logabsdet_jacobian(f, x) # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. if in_varinfo diff --git a/test/contexts.jl b/test/contexts.jl index be976aad4..5768757bb 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - @testset "PriorInit" begin end + empty_varinfos = [ + VarInfo(), + DynamicPPL.typed_varinfo(VarInfo()), + VarInfo(DynamicPPL.VarNamedVector()), + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + SimpleVarInfo(), + SimpleVarInfo(Dict{VarName,Any}()), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + for empty_vi in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + for empty_vi in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + for empty_vi in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end - @testset "UniformInit" begin end + @testset "PriorInit" begin + test_generating_new_values(PriorInit()) + test_replacing_values(PriorInit()) + test_rng_respected(PriorInit()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), PriorInit()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end - @testset "ParamsInit" begin end + @testset "UniformInit" begin + test_generating_new_values(UniformInit()) + test_replacing_values(UniformInit()) + test_rng_respected(UniformInit()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), UniformInit(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end - @testset "rng is respected (at least with PriorInit" begin end + @testset "ParamsInit" begin + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + # In this case, we expect `ParamsInit` to use the value of x, and + # generate a new value for y. + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + for empty_vi in empty_varinfos + _, vi = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), ParamsInit(params_nt) + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), ParamsInit(params_dict) + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + end + end end end From 02ae96507d970b42d0fdf11fc83329c263f98fca Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:30:07 +0100 Subject: [PATCH 04/20] inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). --- src/contexts/init.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6ff276d21..3b7007f51 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -38,6 +38,8 @@ to unconstrained space, and then sampling a value uniformly between `lower` and If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's default initialisation strategy. +Requires that `lower <= upper`. + # References [Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) @@ -55,7 +57,7 @@ end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::UniformInit) b = Bijectors.bijector(dist) sz = Bijectors.output_size(b, size(dist)) - y = rand(rng, Uniform(u.lower, u.upper), sz) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) b_inv = Bijectors.inverse(b) x = b_inv(y) # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 From 55634f4a367da6d931b44b98cfc09289adb64b59 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 00:46:43 +0100 Subject: [PATCH 05/20] Document --- docs/src/api.md | 21 +++++++++++++++++++++ src/contexts/init.jl | 20 ++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 14b2447b5..026dbb999 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -464,6 +464,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +PriorInit +UniformInit +ParamsInit +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 3b7007f51..2b87b533b 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -32,11 +32,11 @@ init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::PriorInit) = rand UniformInit(lower, upper) Obtain new values by first transforming the distribution of the random variable -to unconstrained space, and then sampling a value uniformly between `lower` and -`upper`. +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. -If unspecified, defaults to `(lower, upper) = (-2, 2)`, which mimics Stan's -default initialisation strategy. +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. Requires that `lower <= upper`. @@ -91,17 +91,17 @@ struct ParamsInit{P,S<:AbstractInitStrategy} <: AbstractInitStrategy end end function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::ParamsInit) - # TODO(penelopeysm): We should do a check to make sure that all of the - # parameters in `p.params` were actually used, and either warn or error if - # they aren't. This is non-trivial (we need to use something like - # varname_leaves), so I'm going to defer it to a later PR. + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing init(rng, vn, dist, p.default) else - # TODO(penelopeysm): We could also check that the type of x matches - # the dist? + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? x end else From d6ba16c8b3bf6716a23f69f839704ae9d02d48b4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 19 Jul 2025 23:36:33 +0100 Subject: [PATCH 06/20] Add a test to check that `init!!` doesn't change linking --- src/contexts/init.jl | 2 +- test/contexts.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 2b87b533b..baa1bf088 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -165,7 +165,7 @@ function tilde_assume( # necessary. insert_transformed_value && settrans!!(vi, true, vn) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, -logjac, vn, dist) + vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. return x, vi diff --git a/test/contexts.jl b/test/contexts.jl index 5768757bb..e2a882f48 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -508,11 +508,33 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end end + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end @testset "PriorInit" begin test_generating_new_values(PriorInit()) test_replacing_values(PriorInit()) test_rng_respected(PriorInit()) + test_link_status_respected(PriorInit()) @testset "check that values are within support" begin # Not many other sensible checks we can do for priors. @@ -529,6 +551,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() test_generating_new_values(UniformInit()) test_replacing_values(UniformInit()) test_rng_respected(UniformInit()) + test_link_status_respected(UniformInit()) @testset "check that bounds are respected" begin @testset "unconstrained" begin @@ -559,6 +582,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "ParamsInit" begin + test_link_status_respected(ParamsInit((; a=1.0))) + test_link_status_respected(ParamsInit(Dict(@varname(a) => 1.0))) + @testset "given full set of parameters" begin # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) my_x, my_y = 1.0, [2.0, 3.0] From fd78d42bd5fe9847bbd61e5f3e6f48d2f39fcf06 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 01:58:18 +0100 Subject: [PATCH 07/20] Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. --- src/varnamedvector.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new From d40df7eb9408dc0e94fc69dad88ed510771c7c85 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:40:13 +0100 Subject: [PATCH 08/20] Replace `evaluate_and_sample!!` -> `init!!` --- docs/src/api.md | 12 +++--- src/extract_priors.jl | 2 +- src/model.jl | 42 +++--------------- src/sampler.jl | 3 +- src/simple_varinfo.jl | 48 ++++++++++++--------- src/test_utils/contexts.jl | 72 ++++++++++++++++++++----------- src/test_utils/model_interface.jl | 4 +- src/varinfo.jl | 68 ++++++++++++++++------------- test/compiler.jl | 13 +++--- test/contexts.jl | 19 ++++---- test/model.jl | 25 +---------- test/sampler.jl | 4 +- test/simple_varinfo.jl | 8 ++-- test/varinfo.jl | 31 +++++++------ test/varnamedvector.jl | 4 +- 15 files changed, 166 insertions(+), 189 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 026dbb999..7b0e7f927 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -450,11 +450,6 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. @@ -469,7 +464,12 @@ InitContext ### VarInfo initialisation -`InitContext` is used to initialise, or overwrite, values in a VarInfo. +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. + +```@docs +DynamicPPL.init!! +``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. There are three concrete strategies provided in DynamicPPL: diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 64dcf2eea..5342d70c4 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -117,7 +117,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index 5a07129a0..2cd647291 100644 --- a/src/model.jl +++ b/src/model.jl @@ -850,7 +850,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) return Threads.nthreads() > 1 end -""" - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) - -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. - -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). - -Returns a tuple of the model's return value, plus the updated `varinfo` object. -""" -function evaluate_and_sample!!( - rng::Random.AbstractRNG, - model::Model, - varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), -) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) -end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() -) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) -end - """ init!!( [rng::Random.AbstractRNG,] @@ -897,10 +871,10 @@ end [init_strategy::AbstractInitStrategy=PriorInit()] ) -Evaluate the `model` and replace the values of the model's random variables in -the given `varinfo` with new values using a specified initialisation strategy. -If the values in `varinfo` are not already present, they will be added using -that same strategy. +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added. +using a specified initialisation strategy. If `init_strategy` is not provided, defaults to PriorInit(). @@ -1049,11 +1023,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end diff --git a/src/sampler.jl b/src/sampler.jl index 673b5128f..184e4b70a 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,7 +58,8 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) + strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() + DynamicPPL.init!!(rng, model, vi, strategy) return vi, nothing end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad22bf52d..8d4a8191e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -232,24 +232,23 @@ end # Constructor from `Model`. function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. @@ -265,12 +264,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -482,7 +481,6 @@ function assume( return value, vi end -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end @@ -492,6 +490,16 @@ end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end +function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end +end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..4a019441b 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -29,21 +29,45 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod node_trait = DynamicPPL.NodeTrait(context) # Throw error immediately if it it's missing a `NodeTrait` implementation. node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) + error("Invalid NodeTrait: $node_trait") - # To see change, let's make sure we're using a different leaf context than the current. - leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) else - DefaultContext() + test_parent_context(context, model) end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf + + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. Thus we only test evaluation (i.e., assuming that the + # varinfo already contains all necessary variables). + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + new_model = contextualize(model, context) + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + @testset "{set,}{leaf,child}context" begin + # Ensure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. - varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index e115a6799..9be51cd15 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -113,10 +113,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -129,12 +133,12 @@ the given `rng`, `sampler`. instead. """ function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -195,7 +199,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +207,15 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -270,7 +274,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +282,19 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +302,25 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=PriorInit() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +328,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `PriorInit()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +340,12 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=PriorInit() ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo(model::Model, init_strategy::AbstractInitStrategy=PriorInit()) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..874b71204 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,8 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +595,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +617,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index e2a882f48..a802e0c53 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -166,29 +166,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = FixedContext( + (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) + ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. - varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) + context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) + new_model = contextualize(model, context) + # Initialize a new varinfo with the prefixed model + _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) diff --git a/test/model.jl b/test/model.jl index 81f84e548..63a455c26 100644 --- a/test/model.jl +++ b/test/model.jl @@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() logjoint(model, chain) end - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +314,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -598,10 +580,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] + chain = [VarInfo(m_lin_reg) _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/sampler.jl b/test/sampler.jl index fe9fd331a..d04f39ac9 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -69,8 +69,8 @@ end # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() + DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = UniformInit() + @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == PriorInit() for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) # model with one variable: initialization p = 0.2 diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index be6deb96e..93c7b069e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -160,7 +160,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -228,9 +228,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -275,7 +275,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/varinfo.jl b/test/varinfo.jl index 202ddc1b2..22c3a820b 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -42,7 +42,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), UniformInit()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -470,17 +470,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using UniformInit does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using UniformInit specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to UniformInit tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, UniformInit()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -544,7 +545,7 @@ end function test_linked_varinfo(model, vi) # vn and dist are taken from the containing scope - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, PriorInit())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test istrans(vi, vn) @@ -960,10 +961,9 @@ end end model1 = demo(1) varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -979,10 +979,9 @@ end end model1 = demo_dot(1) varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. + # Calling init!! should preserve the fact that the variables are linked. model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) + varinfo2 = last(DynamicPPL.init!!(model2, deepcopy(varinfo1), PriorInit())) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..af24be86f 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. From a0f308bb6412e685193e94971305b1100874e754 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 01:51:13 +0100 Subject: [PATCH 09/20] Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends --- ext/DynamicPPLMCMCChainsExt.jl | 38 +++++--- src/model.jl | 11 ++- src/varinfo.jl | 143 ---------------------------- test/ext/DynamicPPLMCMCChainsExt.jl | 7 +- test/model.jl | 2 +- test/test_util.jl | 4 +- test/varinfo.jl | 60 +----------- 7 files changed, 49 insertions(+), 216 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..cd86cfb5e 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -28,7 +28,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -42,6 +42,15 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict(c::MCMCChains.Chains, sample_idx, chain_idx) + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Any}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -114,9 +123,15 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, @@ -248,13 +263,14 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. + retval, _ = DynamicPPL.init!!( + model, varinfo, DynamicPPL.ParamsInit(values_dict, DynamicPPL.PriorInit()) + ) + retval end end diff --git a/src/model.jl b/src/model.jl index 2cd647291..2f6036122 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1209,8 +1209,15 @@ function predict( varinfo = DynamicPPL.VarInfo(model) return map(chain) do params_varinfo vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) + # TODO(penelopeysm): Requires two model evaluations, one to extract the + # parameters and one to set them. The reason why we need values_as_in_model + # is because `params_varinfo` may well have some weird combination of + # linked/unlinked, whereas `varinfo` is always unlinked since it is + # freshly constructed. + # This is quite inefficient. It would of course be alright if + # ValuesAsInModelAccumulator was a default acc. + values_nt = values_as_in_model(model, false, params_varinfo) + _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit())) return vi end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 9be51cd15..26a5c34ac 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1514,42 +1514,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. - val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -1972,113 +1936,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke return indices end -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! - unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/model.jl b/test/model.jl index 63a455c26..e5b90b0e2 100644 --- a/test/model.jl +++ b/test/model.jl @@ -580,7 +580,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [VarInfo(m_lin_reg) _ in 1:10000] + chain = [VarInfo(m_lin_reg) for _ in 1:10000] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/test_util.jl b/test/test_util.jl index d5335249d..b7c46ff34 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -81,8 +81,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/varinfo.jl b/test/varinfo.jl index 22c3a820b..3fc770c61 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -262,7 +262,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -313,8 +313,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -339,57 +339,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -403,9 +352,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! on chain" begin From 4ae143cebb02aa96619102cfd7df97ddcfd3b9d0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 12:31:37 +0100 Subject: [PATCH 10/20] Use `init!!` for initialisation --- docs/src/api.md | 2 +- src/sampler.jl | 150 ++++++++++-------------------------------------- test/sampler.jl | 64 ++++----------------- 3 files changed, 41 insertions(+), 175 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 7b0e7f927..9c132a6c4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -508,7 +508,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. diff --git a/src/sampler.jl b/src/sampler.jl index 184e4b70a..e184c4707 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL. provided that supports resuming sampling from a previous state and setting initial parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) +respectively. Additionally, sometimes one might want to implement [`init_strategy`](@ref) that specifies how the initial parameter values are sampled if they are not provided. By default, values are sampled from the prior. """ @@ -68,6 +68,8 @@ end Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns an empty NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. @@ -76,9 +78,10 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(::Random.AbstractRNG, ::Model, ::AbstractSampler) + # Note that variable values are unconditionally initialized later, so no + # point putting them in now. + return typed_varinfo(VarInfo()) end function AbstractMCMC.sample( @@ -96,24 +99,32 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `PriorInit()`, but can be overridden. +""" +init_strategy(::Sampler) = PriorInit() + function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. + # Generate the default varinfo (usually this just makes an empty VarInfo + # with NamedTuple of Metadata). vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. Note that, if `ParamsInit` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -131,110 +142,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. - new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/test/sampler.jl b/test/sampler.jl index d04f39ac9..9555c80d6 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -82,7 +82,9 @@ sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) let inits = (; p=0.2) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -110,7 +112,9 @@ model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) for inits in ([4, -1], (; s=4, m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogjoint(chain[1]) == lptrue @@ -122,7 +126,7 @@ MCMCThreads(), 1, 10; - initial_params=fill(inits, 10), + initial_params=fill(ParamsInit(inits), 10), progress=false, ) for c in chains @@ -133,8 +137,10 @@ end # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) - chain = sample(model, sampler, 1; initial_params=inits, progress=false) + for inits in ((; s=missing, m=-1), (; m=-1)) + chain = sample( + model, sampler, 1; initial_params=ParamsInit(inits), progress=false + ) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -153,54 +159,6 @@ @test c[1].metadata.m.vals == [-1] end end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) end end end From c7e33e7f43701c2accb0b4f953e185749dd9b60f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:10:00 +0100 Subject: [PATCH 11/20] Paper over the `Sampling->Init` context stack (pending removal of SamplingContext) --- src/context_implementations.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..484345a89 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -28,6 +28,13 @@ end function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi) + @warn( + "Encountered SamplingContext->InitContext. This method will be removed in the next PR.", + ) + # just pretend the `InitContext` isn't there for now. + return assume(rng, sampler, right, vn, vi) +end function tilde_assume(::DefaultContext, sampler, right, vn, vi) # same as above but no rng return assume(Random.default_rng(), sampler, right, vn, vi) From 4b3df70df191df259bbf370e1087d24ebea615fc Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 16:25:00 +0100 Subject: [PATCH 12/20] Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway --- ext/DynamicPPLJETExt.jl | 15 +++++---------- test/ext/DynamicPPLJETExt.jl | 5 +++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..89a36ffaf 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -21,22 +21,17 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. + # Let's make sure that evaluation doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + model, varinfo; only_ddpl ) if !issuccess # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" + @debug "Evaluation with typed varinfo failed with the following issues:" @debug result end @@ -46,7 +41,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..9f7e05cb0 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -40,6 +40,11 @@ end end @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa + DynamicPPL.NTVarInfo + init_model = DynamicPPL.contextualize( + demo4(), DynamicPPL.InitContext(DynamicPPL.PriorInit()) + ) + @test DynamicPPL.Experimental.determine_suitable_varinfo(init_model) isa DynamicPPL.UntypedVarInfo # In this model, the type error occurs in the user code rather than in DynamicPPL. From ef92a4bb9e82e4a2e53fc4f80982b3f9fa56d570 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 26 Jul 2025 19:32:27 +0100 Subject: [PATCH 13/20] Remove `predict` on vector of VarInfo --- src/model.jl | 28 ++-------------------------- test/model.jl | 34 ---------------------------------- 2 files changed, 2 insertions(+), 60 deletions(-) diff --git a/src/model.jl b/src/model.jl index 2f6036122..501ee16a8 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1195,32 +1195,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) - -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches -the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. -""" -function predict( - rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} -) - varinfo = DynamicPPL.VarInfo(model) - return map(chain) do params_varinfo - vi = deepcopy(varinfo) - # TODO(penelopeysm): Requires two model evaluations, one to extract the - # parameters and one to set them. The reason why we need values_as_in_model - # is because `params_varinfo` may well have some weird combination of - # linked/unlinked, whereas `varinfo` is always unlinked since it is - # freshly constructed. - # This is quite inefficient. It would of course be alright if - # ValuesAsInModelAccumulator was a default acc. - values_nt = values_as_in_model(model, false, params_varinfo) - _, vi = DynamicPPL.init!!(rng, model, vi, ParamsInit(values_nt, PriorInit())) - return vi - end -end +# Implemented & documented in DynamicPPLMCMCChainsExt +function predict end """ returned(model::Model, parameters::NamedTuple) diff --git a/test/model.jl b/test/model.jl index e5b90b0e2..9838a2d2d 100644 --- a/test/model.jl +++ b/test/model.jl @@ -566,40 +566,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [VarInfo(m_lin_reg) for _ in 1:10000] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end end @testset "ProductNamedTupleDistribution sampling" begin From ec2632b7f4e28463226673f279508f4bb1a36c17 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 12:05:59 +0100 Subject: [PATCH 14/20] Fix some tests --- src/sampler.jl | 4 ++-- test/model.jl | 10 ++++++++-- test/sampler.jl | 24 +++++++++--------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index e184c4707..c4d040959 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -59,8 +59,8 @@ function AbstractMCMC.step( ) vi = VarInfo() strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() - DynamicPPL.init!!(rng, model, vi, strategy) - return vi, nothing + _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) + return new_vi, nothing end """ diff --git a/test/model.jl b/test/model.jl index 9838a2d2d..964383c56 100644 --- a/test/model.jl +++ b/test/model.jl @@ -495,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -541,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), ) predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) @test size(multiple_β_chain, 3) == size(predictions, 3) diff --git a/test/sampler.jl b/test/sampler.jl index 9555c80d6..1a61b2f1f 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -25,8 +25,8 @@ @test length(chains) == N # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + # will be drawn from U[-2, 2] and its mean should be 0. + @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 @@ -81,10 +81,8 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + let inits = ParamsInit((; p=0.2)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -111,10 +109,8 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + let inits = ParamsInit((; s=4, m=-1)) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogjoint(chain[1]) == lptrue @@ -126,7 +122,7 @@ MCMCThreads(), 1, 10; - initial_params=fill(ParamsInit(inits), 10), + initial_params=fill(inits, 10), progress=false, ) for c in chains @@ -137,10 +133,8 @@ end # set only m = -1 - for inits in ((; s=missing, m=-1), (; m=-1)) - chain = sample( - model, sampler, 1; initial_params=ParamsInit(inits), progress=false - ) + for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1))) + chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] From 23cafe0b82cbeea5c6686e8440bffce99b14c723 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 20 Jul 2025 13:08:47 +0100 Subject: [PATCH 15/20] Remove duplicated test --- test/varinfo.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 3fc770c61..3cc547449 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -502,6 +502,11 @@ end @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) @@ -512,11 +517,6 @@ end vi = DynamicPPL.settrans!!(vi, true, vn) test_linked_varinfo(model, vi) - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi) - ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) From 707bc4e7084983d020c27affa68795e3baa50651 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 17:57:23 +0100 Subject: [PATCH 16/20] Remove `SamplingContext` for good --- docs/src/api.md | 12 +--- ext/DynamicPPLEnzymeCoreExt.jl | 2 - src/DynamicPPL.jl | 3 - src/context_implementations.jl | 113 ++------------------------------- src/contexts.jl | 69 +------------------- src/debug_utils.jl | 2 +- src/sampler.jl | 45 ------------- src/simple_varinfo.jl | 19 ------ src/utils.jl | 44 ------------- test/Project.toml | 2 - test/ad.jl | 42 ------------ test/contexts.jl | 23 +------ test/debug_utils.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 11 +--- test/lkj.jl | 33 +++------- test/sampler.jl | 54 ---------------- test/threadsafe.jl | 10 ++- 17 files changed, 26 insertions(+), 460 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9c132a6c4..5aa024b8c 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -450,12 +450,12 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext PrefixContext ConditionContext @@ -489,15 +489,7 @@ DynamicPPL.init ### Samplers -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. - -```@docs -SampleFromPrior -SampleFromUniform -``` - -Additionally, a generic sampler for inference is implemented. +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index ceb3f4981..f2d24ad92 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive_noinl(::typeof(DynamicPPL.istrans), args...) = diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6050ce344..a19647eb9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -97,13 +97,10 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 484345a89..e38ffe6e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,45 +1,14 @@ # assume -""" - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` -""" -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - function tilde_assume(context::AbstractContext, args...) return tilde_assume(childcontext(context), args...) end function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::InitContext, sampler, right, vn, vi) - @warn( - "Encountered SamplingContext->InitContext. This method will be removed in the next PR.", - ) - # just pretend the `InitContext` isn't there for now. - return assume(rng, sampler, right, vn, vi) -end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi end - function tilde_assume(context::PrefixContext, right, vn, vi) # Note that we can't use something like this here: # new_vn = prefix(context, vn) @@ -53,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi) new_vn, new_context = prefix_and_strip_contexts(context, vn) return tilde_assume(new_context, right, new_vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi -) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) -end """ tilde_assume!!(context, right, vn, vi) @@ -78,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi) end # observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. - -Falls back to `tilde_observe!!(context.context, right, left, vi)`. -""" -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - function tilde_observe!!(context::AbstractContext, right, left, vn, vi) return tilde_observe!!(childcontext(context), right, left, vn, vi) end @@ -121,58 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, -) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi -end diff --git a/src/contexts.jl b/src/contexts.jl index cd9876768..8b5e866d0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. -""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index d71fa57cc..ef9e1b8cf 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -475,7 +475,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/sampler.jl b/src/sampler.jl index c4d040959..711865008 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - strategy = sampler isa SampleFromPrior ? PriorInit() : UniformInit() - _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) - return new_vi, nothing -end - """ default_varinfo(rng, model, sampler) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 8d4a8191e..9bb56830d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -462,25 +462,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end diff --git a/src/utils.jl b/src/utils.jl index d3371271f..c70576a5d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/test/Project.toml b/test/Project.toml index 6da3786f5..3bd424237 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -39,7 +38,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9, 0.10" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 371e79b06..23e676ee7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,48 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/contexts.jl b/test/contexts.jl index a802e0c53..9cff6f953 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -24,8 +24,6 @@ using DynamicPPL: using LinearAlgebra: I using Random: Xoshiro -using EnzymeCore - # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -50,7 +48,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( @@ -151,11 +148,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -204,22 +201,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 9f7e05cb0..a820d885e 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -67,19 +67,14 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling + # Check that the inferred varinfo is indeed suitable for evaluation f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed @@ -90,10 +85,6 @@ model, typed_vi ) JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi - ) - JET.test_call(f_sample, argtypes_sample) end end end diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..03e744b84 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -38,20 +33,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [PriorInit(), UniformInit()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/sampler.jl b/test/sampler.jl index 1a61b2f1f..5b6a623e8 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,58 +1,4 @@ @testset "sampler.jl" begin - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # will be drawn from U[-2, 2] and its mean should be 0. - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling abstract type OnlyInitAlg end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 24a738a78..c86f8da69 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From 13988b5234a7c84f8d2944f8f98d713bba0bc1d4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 10 Jul 2025 18:05:29 +0100 Subject: [PATCH 17/20] Remove `tilde_assume` as well --- docs/src/api.md | 15 ++++++++------- src/DynamicPPL.jl | 5 +++-- src/context_implementations.jl | 23 ++++++++--------------- src/contexts.jl | 2 +- src/contexts/init.jl | 2 +- src/transforming.jl | 2 +- 6 files changed, 22 insertions(+), 27 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 5aa024b8c..1a9bac7fe 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. -These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. ```@docs @model @@ -344,6 +344,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -515,9 +522,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a19647eb9..b1154142e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -104,8 +104,9 @@ export AbstractVarInfo, DefaultContext, PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, # Initialisation InitContext, AbstractInitStrategy, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e38ffe6e6..f25b63a64 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,18 +1,18 @@ # assume -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) +function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) + return tilde_assume!!(childcontext(context), right, vn, vi) end -function tilde_assume(::DefaultContext, right, vn, vi) +function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, right) x, inv_logjac = with_logabsdet_jacobian(f, y) vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) return x, vi end -function tilde_assume(context::PrefixContext, right, vn, vi) +function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) # Note that we can't use something like this here: # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) # This is because `prefix` applies _all_ prefixes in a given context to a # variable name. Thus, if we had two levels of nested prefixes e.g. # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the @@ -20,7 +20,7 @@ function tilde_assume(context::PrefixContext, right, vn, vi) # would apply the prefix `b._`, resulting in `b.a.b._`. # This is why we need a special function, `prefix_and_strip_contexts`. new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) + return tilde_assume!!(new_context, right, new_vn, vi) end """ @@ -28,16 +28,9 @@ end Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value and updated `vi`. - -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. """ -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end +function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) + return _evaluate!!(right, vi, context, vn) end # observe diff --git a/src/contexts.jl b/src/contexts.jl index 8b5e866d0..439da47e5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -185,7 +185,7 @@ PrefixContexts removed. NOTE: This does _not_ modify any variables in any `ConditionContext` and `FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline than `contextual_isassumption` and `contextual_isfixed` (the functions which actually use the `ConditionContext` and `FixedContext` values). Thus, by this time, any `ConditionContext`s and `FixedContext`s present have already served diff --git a/src/contexts/init.jl b/src/contexts/init.jl index baa1bf088..631cc19e6 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -134,7 +134,7 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon end NodeTrait(::InitContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) diff --git a/src/transforming.jl b/src/transforming.jl index 56f861cff..3569d1502 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -12,7 +12,7 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. From 331279c1984e9cebe48ebdc6db9e29032848cbc9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 18 Jul 2025 17:42:13 +0100 Subject: [PATCH 18/20] Split up tilde_observe!! for Distribution / Submodel --- src/context_implementations.jl | 8 +++++--- src/transforming.jl | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f25b63a64..92200582c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -60,9 +60,11 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +function tilde_observe!!(::DefaultContext, right::Distribution, left, vn, vi) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end + +function tilde_observe!!(::DefaultContext, ::DynamicPPL.Submodel, left, vn, vi) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +end diff --git a/src/transforming.jl b/src/transforming.jl index 3569d1502..5465b2ff2 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -13,7 +13,7 @@ struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, right, vn, vi + ::DynamicTransformationContext{isinverse}, right::Distribution, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +31,7 @@ function tilde_assume!!( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) +function tilde_observe!!(::DynamicTransformationContext, right::Distribution, left, vn, vi) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end From cf87ce7c04ef0a21fece53296c10e09f8d9c31b8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 5 Aug 2025 16:57:19 +0100 Subject: [PATCH 19/20] Move `PrefixContext` to a model field --- docs/src/api.md | 1 - src/DynamicPPL.jl | 2 +- src/compiler.jl | 47 ++++---- src/context_implementations.jl | 38 +----- src/contexts.jl | 208 --------------------------------- src/contexts/init.jl | 2 +- src/model.jl | 135 +++++++-------------- src/prefix.jl | 108 +++++++++++++++++ src/submodel.jl | 26 +++-- src/transforming.jl | 2 +- test/contexts.jl | 191 +----------------------------- test/submodels.jl | 49 +++++++- 12 files changed, 246 insertions(+), 563 deletions(-) create mode 100644 src/prefix.jl diff --git a/docs/src/api.md b/docs/src/api.md index 1a9bac7fe..3dd157281 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -464,7 +464,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs DefaultContext -PrefixContext ConditionContext InitContext ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b1154142e..4c2702f17 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -102,7 +102,6 @@ export AbstractVarInfo, # Contexts contextualize, DefaultContext, - PrefixContext, ConditionContext, # Tilde pipeline tilde_assume!!, @@ -177,6 +176,7 @@ include("chains.jl") include("contexts.jl") include("contexts/init.jl") include("model.jl") +include("prefix.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 6384eaa7c..4266ac9db 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -60,11 +60,14 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) +function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr)) + @gensym vn return quote - if $(DynamicPPL.contextual_isassumption)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like + # the whole `isassumption` thing to be simplified, though, so I'll + # leave it till later. + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn) # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($left_vn, __model__)) || + $(DynamicPPL.inmissings)($left_vn, __model__) true else $(maybe_view(expr)) === missing @@ -99,7 +102,7 @@ isassumption(expr) = :(false) Return `true` if `vn` is considered an assumption by `context`. """ -function contextual_isassumption(context::AbstractContext, vn) +function contextual_isassumption(context::AbstractContext, vn::VarName) if hasconditioned_nested(context, vn) val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? @@ -115,9 +118,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) - return :($(DynamicPPL.contextual_isfixed)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - )) + return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn)) end """ @@ -413,7 +414,9 @@ function generate_assign(left, right) return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) - $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) + $vn = $(DynamicPPL.maybe_prefix)( + $(make_varname_expression(left)), __model__.prefix + ) __varinfo__ = $(map_accumulator!!)( $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) @@ -448,24 +451,23 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption value dist + @gensym left_vn vn isassumption value dist return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) - $isassumption = $(DynamicPPL.isassumption(left, vn)) + $left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + $isassumption = $(DynamicPPL.isassumption(left, left_vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + $left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # If `left_vn` is not in `argnames`, we need to make sure that the variable is defined. + # (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had + # prefixes applied!) + if !$(DynamicPPL.inargnames)($left_vn, __model__) + $left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __model__.context, + __model__.prefix, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 92200582c..1afad3963 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,36 +1,23 @@ # assume -function tilde_assume!!(context::AbstractContext, right::Distribution, vn, vi) - return tilde_assume!!(childcontext(context), right, vn, vi) +function tilde_assume!!(context::AbstractContext, prefix, right::Distribution, vn, vi) + return tilde_assume!!(childcontext(context), prefix, right, vn, vi) end -function tilde_assume!!(::DefaultContext, right::Distribution, vn, vi) +function tilde_assume!!(::DefaultContext, prefix, right::Distribution, vn, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, right) x, inv_logjac = with_logabsdet_jacobian(f, y) vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) return x, vi end -function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume!!(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume!!(new_context, right, new_vn, vi) -end """ - tilde_assume!!(context, right, vn, vi) + tilde_assume!!(context, prefix, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value and updated `vi`. """ -function tilde_assume!!(context, right::DynamicPPL.Submodel, vn, vi) - return _evaluate!!(right, vi, context, vn) +function tilde_assume!!(context, prefix, right::DynamicPPL.Submodel, vn, vi) + return _evaluate!!(right, vi, context, prefix, vn) end # observe @@ -38,19 +25,6 @@ function tilde_observe!!(context::AbstractContext, right, left, vn, vi) return tilde_observe!!(childcontext(context), right, left, vn, vi) end -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. - new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end - """ tilde_observe!!(context, right, left, vn, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 439da47e5..0679ed7e3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -130,89 +130,6 @@ when running the model. struct DefaultContext <: AbstractContext end NodeTrait(::DefaultContext) = IsLeaf() -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. - -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) - -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - -""" - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. -""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn - ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} @@ -298,9 +215,6 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end """ getconditioned_nested(context, vn) @@ -316,9 +230,6 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) getconditioned(context, vn) @@ -387,9 +298,6 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -452,9 +360,6 @@ hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end """ getfixed_nested(context, vn) @@ -470,9 +375,6 @@ end function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) getfixed(context, vn) @@ -566,113 +468,3 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end - -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. - -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) - -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); - -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end - -""" - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) - -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. - -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext - -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) - -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` -""" -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) -end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 631cc19e6..23b7fa7ab 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -135,7 +135,7 @@ end NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( - ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo + ctx::InitContext, prefix, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) # `init()` always returns values in original space, i.e. possibly diff --git a/src/model.jl b/src/model.jl index 501ee16a8..21f94f8c9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,9 +1,19 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + Tprefix<:Union{Nothing,<:VarName} + } f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx=DefaultContext() + prefix::Tprefix=nothing end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` @@ -33,12 +43,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + Tprefix<:Union{Nothing,<:VarName}, +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + prefix::Tprefix @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +70,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + prefix::Tprefix=nothing, + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Tprefix} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Tprefix}( + f, args, defaults, context, prefix ) end end @@ -71,18 +91,27 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), -) where {F,argnames,Targs,kwargnames,Tkwargs} + prefix::Tprefix=nothing, +) where {F,argnames,Targs,kwargnames,Tkwargs,Tprefix} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...)}( + f, args, defaults, context, prefix + )) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + prefix::Union{Nothing,<:VarName}=nothing; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, prefix) end """ @@ -92,7 +121,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, model.prefix) end """ @@ -430,33 +459,6 @@ julia> m = demo(); julia> # Returns all the variables we have conditioned on + their values. conditioned(condition(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) - -julia> # Nested ones also work. - # (Note that `PrefixContext` also prefixes the variables of any - # ConditionContext that is _inside_ it; because of this, the type of the - # container has to be broadened to a `Dict`.) - cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); - -julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) -true - -julia> # Since we conditioned on `a.m`, it is not treated as a random variable. - # However, `a.x` will still be a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x - -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); - -julia> conditioned(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 - -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x ``` """ conditioned(model::Model) = conditioned(model.context) @@ -773,67 +775,10 @@ julia> m = demo(); julia> # Returns all the variables we have fixed on + their values. fixed(fix(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) - -julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); - -julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) -true - -julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x - -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); - -julia> fixed(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 - -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x ``` """ fixed(model::Model) = fixed(model.context) -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ (model::Model)([rng, varinfo]) diff --git a/src/prefix.jl b/src/prefix.jl new file mode 100644 index 000000000..c8f258cac --- /dev/null +++ b/src/prefix.jl @@ -0,0 +1,108 @@ +""" + maybe_prefix(inner::Union{Nothing,<:VarName}, outer::Union{Nothing,<:VarName}) + +Prefix `inner` with the prefix `outer`. Both `inner` and `outer` can be either +`VarName`s or `Nothing`. + +Note that this differs from `AbstractPPL.prefix` in that it handles `nothing` values. +This can happen e.g. when prefixing a model that is not already prefixed; or when +executing submodels without automatic prefixing. +""" +maybe_prefix(inner::VarName, outer::VarName) = AbstractPPL.prefix(inner, outer) +maybe_prefix(vn::VarName, ::Nothing) = vn +maybe_prefix(::Nothing, vn::VarName) = vn +maybe_prefix(::Nothing, ::Nothing) = nothing + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end + +""" + DynamicPPL.prefix(model::Model, x::VarName) + DynamicPPL.prefix(model::Model, x::Val{sym}) + DynamicPPL.prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +If `x` is `nothing`, then the model is returned unchanged. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, ::Nothing) = model +function prefix(model::Model, vn::VarName) + # Add it to the model prefix field + new_prefix = maybe_prefix(model.prefix, vn) + # And also make sure to prefix any conditioned and fixed variables stored in the model + new_context = prefix_cond_and_fixed_variables(model.context, vn) + return Model(model.f, model.args, model.defaults, new_context, new_prefix) +end +prefix(model::Model, ::Val{sym}) where {sym} = prefix(model, VarName{sym}()) +prefix(model::Model, x) = return prefix(model, VarName{Symbol(x)}()) diff --git a/src/submodel.jl b/src/submodel.jl index dcb107bb4..4ed22db1d 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -158,28 +158,32 @@ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}( # passed into this function. # # `parent_context` here refers to the context of the model that contains the -# submodel. +# submodel. `parent_prefix` is the prefix that is applied to the parent model. function _evaluate!!( submodel::Submodel{M,AutoPrefix}, vi::AbstractVarInfo, parent_context::AbstractContext, - left_vn::VarName, + parent_prefix::Union{Nothing,<:VarName}, + vn::VarName, ) where {M<:Model,AutoPrefix} # First, we construct the context to be used when evaluating the submodel. There # are several considerations here: - # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but - # _only_ if automatic prefixing is supposed to be applied. - submodel_context_prefixed = if AutoPrefix - PrefixContext(left_vn, submodel.model.context) + + # (1) Before even touching the contexts, we need to make sure that we apply + # automatic prefixing if it was requested. (If the prefix was manually applied, then + # `prefix()` will have been called by the user, and we don't need to do it again.) + submodel_prefix = if AutoPrefix + # Note that by the time we see it here (in `tilde_assume!!`), `vn` + # has already prefixed with `parent_prefix`, so no need to re-prefix it + vn else - submodel.model.context + parent_prefix end + submodel_model = DynamicPPL.prefix(submodel.model, submodel_prefix) # (2) We need to respect the leaf-context of the parent model. This, unfortunately, # means disregarding the leaf-context of the submodel. - submodel_context = setleafcontext( - submodel_context_prefixed, leafcontext(parent_context) - ) + submodel_context = setleafcontext(submodel_model.context, leafcontext(parent_context)) # (3) We need to use the parent model's context to wrap the whole thing, so that # e.g. if the user conditions the parent model, the conditioned variables will be @@ -187,7 +191,7 @@ function _evaluate!!( eval_context = setleafcontext(parent_context, submodel_context) # (4) Finally, we need to store that context inside the submodel. - model = contextualize(submodel.model, eval_context) + model = contextualize(submodel_model, eval_context) # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This # returns a tuple of submodel.model's return value and the new varinfo. diff --git a/src/transforming.jl b/src/transforming.jl index 5465b2ff2..22493b49b 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -13,7 +13,7 @@ struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, right::Distribution, vn, vi + ::DynamicTransformationContext{isinverse}, prefix, right::Distribution, vn, vi ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] diff --git a/test/contexts.jl b/test/contexts.jl index 9cff6f953..058b98d14 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -18,9 +18,7 @@ using DynamicPPL: conditioned, fixed, hasconditioned_nested, - getconditioned_nested, - collapse_prefix_stack, - prefix_cond_and_fixed_variables + getconditioned_nested using LinearAlgebra: I using Random: Xoshiro @@ -48,15 +46,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - :condition3 => ConditionContext( - (x=1.0,), - PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), - ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -118,89 +111,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext" begin - @testset "prefixing" begin - ctx = @inferred PrefixContext( - @varname(a), - PrefixContext( - @varname(b), - PrefixContext( - @varname(c), - PrefixContext( - @varname(d), - PrefixContext( - @varname(e), PrefixContext(@varname(f), DefaultContext()) - ), - ), - ), - ), - ) - vn = @varname(x) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x) - - vn = @varname(x[1]) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) - end - - @testset "nested within arbitrary context stacks" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) - @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext(@varname(b), ctx2) - @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = FixedContext(Dict(), ctx3) - @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) - end - - @testset "prefix_and_strip_contexts" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == DefaultContext() - - ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == FixedContext((b=4,)) - - ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == ConditionContext((a=1,)) - - ctx4 = FixedContext( - (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) - ) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) - end - - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) - new_model = contextualize(model, context) - # Initialize a new varinfo with the prefixed model - _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) - # Extract the resulting varnames - vns_actual = Set(keys(varinfo)) - - # Extract the ground truth varnames - vns_expected = Set([ - AbstractPPL.prefix(vn, prefix_vn) for - vn in DynamicPPL.TestUtils.varnames(model) - ]) - - # Check that all variables are prefixed correctly. - @test vns_actual == vns_expected - end - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin @@ -316,105 +226,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext + Condition/FixedContext interactions" begin - @testset "prefix_cond_and_fixed_variables" begin - c1 = ConditionContext((c=1, d=2)) - c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) - @test c1_prefixed isa ConditionContext - @test childcontext(c1_prefixed) isa DefaultContext - @test c1_prefixed.values[@varname(a.c)] == 1 - @test c1_prefixed.values[@varname(a.d)] == 2 - - c2 = FixedContext((f=1, g=2)) - c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) - @test c2_prefixed isa FixedContext - @test childcontext(c2_prefixed) isa DefaultContext - @test c2_prefixed.values[@varname(a.f)] == 1 - @test c2_prefixed.values[@varname(a.g)] == 2 - - c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) - c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) - c3_prefixed_child = childcontext(c3_prefixed) - @test c3_prefixed isa ConditionContext - @test c3_prefixed.values[@varname(a.c)] == 1 - @test c3_prefixed.values[@varname(a.d)] == 2 - @test c3_prefixed_child isa FixedContext - @test c3_prefixed_child.values[@varname(a.f)] == 1 - @test c3_prefixed_child.values[@varname(a.g)] == 2 - @test childcontext(c3_prefixed_child) isa DefaultContext - end - - @testset "collapse_prefix_stack" begin - # Utility function to make sure that there are no PrefixContexts in - # the context stack. - function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) - end - - # Prefix -> Condition - c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) - c1 = collapse_prefix_stack(c1) - @test has_no_prefixcontexts(c1) - c1_vals = conditioned(c1) - @test length(c1_vals) == 2 - @test getvalue(c1_vals, @varname(a.c)) == 1 - @test getvalue(c1_vals, @varname(a.d)) == 2 - - # Condition -> Prefix - c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) - c2 = collapse_prefix_stack(c2) - @test has_no_prefixcontexts(c2) - c2_vals = conditioned(c2) - @test length(c2_vals) == 2 - @test getvalue(c2_vals, @varname(c)) == 1 - @test getvalue(c2_vals, @varname(d)) == 2 - - # Prefix -> Fixed - c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) - c3 = collapse_prefix_stack(c3) - c3_vals = fixed(c3) - @test length(c3_vals) == 2 - @test length(c3_vals) == 2 - @test getvalue(c3_vals, @varname(a.f)) == 1 - @test getvalue(c3_vals, @varname(a.g)) == 2 - - # Fixed -> Prefix - c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) - c4 = collapse_prefix_stack(c4) - @test has_no_prefixcontexts(c4) - c4_vals = fixed(c4) - @test length(c4_vals) == 2 - @test getvalue(c4_vals, @varname(f)) == 1 - @test getvalue(c4_vals, @varname(g)) == 2 - - # Prefix -> Condition -> Prefix -> Condition - c5 = PrefixContext( - @varname(a), - ConditionContext( - (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) - ), - ) - c5 = collapse_prefix_stack(c5) - @test has_no_prefixcontexts(c5) - c5_vals = conditioned(c5) - @test length(c5_vals) == 2 - @test getvalue(c5_vals, @varname(a.c)) == 1 - @test getvalue(c5_vals, @varname(a.b.d)) == 2 - - # Prefix -> Condition -> Prefix -> Fixed - c6 = PrefixContext( - @varname(a), - ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), - ) - c6 = collapse_prefix_stack(c6) - @test has_no_prefixcontexts(c6) - @test conditioned(c6) == Dict(@varname(a.c) => 1) - @test fixed(c6) == Dict(@varname(a.b.d) => 2) - end - end - @testset "InitContext" begin empty_varinfos = [ VarInfo(), diff --git a/test/submodels.jl b/test/submodels.jl index 986aea1d0..7463ed0e2 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -135,7 +135,54 @@ end end end - @testset "Nested submodels" begin + @testset "Nested submodels with auto prefix" begin + @model function f() + x ~ Normal() + return y ~ Normal() + end + @model function g() + return b ~ to_submodel(f()) + end + @model function h() + return a ~ to_submodel(g()) + end + + # No conditioning + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogjoint(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) + + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + + # Conditioning/fixing at the second level + op_g = op(g(), (@varname(b.x) => x_val)) + @model function h2() + return a ~ to_submodel(op_g) + end + + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) + @model function g2() + return _unused ~ to_submodel(prefix(op_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end + end + + @testset "Nested submodels with manual prefix" begin + # Same tests as above, just that the middle layer has manual prefixing + # rather than automatic. @model function f() x ~ Normal() return y ~ Normal() From c670ef0e57d5b30be2766cb132abbd8587d34174 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 5 Aug 2025 17:10:59 +0100 Subject: [PATCH 20/20] Re-add tests and doctests --- src/model.jl | 44 +++++++++++++++--- test/prefix.jl | 121 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+), 6 deletions(-) create mode 100644 test/prefix.jl diff --git a/src/model.jl b/src/model.jl index 21f94f8c9..7e4134993 100644 --- a/src/model.jl +++ b/src/model.jl @@ -446,7 +446,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned julia> @model function demo() m ~ Normal() @@ -456,8 +456,24 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have conditioned on + their values. - conditioned(condition(m, x=100.0, m=1.0)) +julia> # Condition on some values. + cm = m | (; x = 100.0, m = 1.0); + +julia> # Returns all the variables we have conditioned on, and their values. + conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # If we prefix the model, the conditioned variables will also be prefixed. + pm = prefix(cm, @varname(f)); conditioned(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 + +julia> # If we condition _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); cm2 = pm2 | (; x = 100.0, m = 1.0); + +julia> # When running this model, the variables inside are not treated as conditioned! + conditioned(cm2) (x = 100.0, m = 1.0) ``` """ @@ -762,7 +778,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed julia> @model function demo() m ~ Normal() @@ -772,8 +788,24 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have fixed on + their values. - fixed(fix(m, x=100.0, m=1.0)) +julia> # Fix some values. + fm = fix(m, (; x = 100.0, m = 1.0)); + +julia> # Returns all the variables we have fixed on, and their values. + fixed(fm) +(x = 100.0, m = 1.0) + +julia> # If we prefix the model, the fixed variables will also be prefixed. + pm = prefix(fm, @varname(f)); fixed(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 + +julia> # If we fix _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); fm2 = fix(pm2, (; x = 100.0, m = 1.0)); + +julia> # When running this model, the variables inside are not treated as fixed! + fixed(fm2) (x = 100.0, m = 1.0) ``` """ diff --git a/test/prefix.jl b/test/prefix.jl new file mode 100644 index 000000000..57065689b --- /dev/null +++ b/test/prefix.jl @@ -0,0 +1,121 @@ +""" +Note that `test/submodel.jl` also contains a number of tests which make use of +prefixing functionality (more like end-to-end tests). This file contains what +are essentially unit tests for prefixing functions. +""" +module DPPLPrefixTests + +using DynamicPPL +# not exported +using DynamicPPL: FixedContext, prefix_cond_and_fixed_variables, childcontext +using Distributions +using Test + +@testset "prefix.jl" begin + @testset "prefix_cond_and_fixed_variables" begin + @testset "ConditionContext" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test length(c1_prefixed.values) == 2 + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + end + + @testset "FixedContext" begin + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test length(c2_prefixed.values) == 2 + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + end + + @testset "Nested ConditionContext and FixedContext" begin + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test length(c3_prefixed.values) == 2 + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test length(c3_prefixed_child.values) == 2 + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + end + + @testset "DynamicPPL.prefix(::Model, x)" begin + @model function demo() + x ~ Normal() + return y ~ Normal() + end + model = demo() + + @testset "No conditioning / fixing" begin + pmodel = DynamicPPL.prefix(model, @varname(a)) + @test pmodel.prefix == @varname(a) + vi = VarInfo(pmodel) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end + + @testset "Prefixing a conditioned model" begin + cmodel = model | (; x=1.0) + # Sanity check. + vi = VarInfo(cmodel) + @test Set(keys(vi)) == Set([@varname(y)]) + # Now prefix. + pcmodel = DynamicPPL.prefix(cmodel, @varname(a)) + @test pcmodel.prefix == @varname(a) + # Because the model was conditioned on `x` _prior_ to prefixing, + # the resulting `a.x` variable should also be conditioned. In + # other words, which variables are treated as conditioned should be + # invariant to prefixing. + vi = VarInfo(pcmodel) + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + + @testset "Prefixing a fixed model" begin + # Same as above but for FixedContext rather than Condition. + fmodel = fix(model, (; y=1.0)) + # Sanity check. + vi = VarInfo(fmodel) + @test Set(keys(vi)) == Set([@varname(x)]) + # Now prefix. + pfmodel = DynamicPPL.prefix(fmodel, @varname(a)) + @test pfmodel.prefix == @varname(a) + # Because the model was conditioned on `x` _prior_ to prefixing, + # the resulting `a.x` variable should also be conditioned. In + # other words, which variables are treated as conditioned should be + # invariant to prefixing. + vi = VarInfo(pfmodel) + @test Set(keys(vi)) == Set([@varname(a.x)]) + end + + @testset "Conditioning a prefixed model" begin + # If the prefixing happens first, then we want to make sure that the + # user is forced to apply conditioning WITH the prefix. + pmodel = DynamicPPL.prefix(model, @varname(a)) + + # If this doesn't happen... + cpmodel_wrong = pmodel | (; x=1.0) + @test cpmodel_wrong.prefix == @varname(a) + vi = VarInfo(cpmodel_wrong) + # Then `a.x` will be `assume`d + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + + # If it does... + cpmodel_right = pmodel | (@varname(a.x) => 1.0) + @test cpmodel_right.prefix == @varname(a) + vi = VarInfo(cpmodel_right) + # Then `a.x` will be `observe`d + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + end +end + +end