From b1cdc2af4ef5b89c83493ae268e53af8f52a6f25 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 8 Aug 2025 11:17:00 +0100 Subject: [PATCH 01/14] Bump minor version --- Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1f37515ab..e61adae8c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.0" +version = "0.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..124da3315 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -18,7 +18,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.37" +DynamicPPL = "0.38" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" From 5a9e9d221a015477701e1a992e5fe8579c66f32a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 10 Aug 2025 14:44:03 +0100 Subject: [PATCH 02/14] bump benchmarks compat --- benchmarks/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 3d14d03ff..cd4545cb9 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -22,7 +22,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.37" +DynamicPPL = "0.38" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" From 7b55aa30437ae044ab8f925053a9d838000ed224 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 16:58:43 +0100 Subject: [PATCH 03/14] add a skeletal changelog --- HISTORY.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index c0db1cd5d..87bd2d552 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.38.0 + +[...] + ## 0.37.1 Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147. From 991e825334a5273b477c82fcb753c625063b71ae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 17:47:03 +0100 Subject: [PATCH 04/14] `InitContext`, part 3 - Introduce `InitContext` (#981) * Implement InitContext * Fix loading order of modules; move `prefix(::Model)` to model.jl * Add tests for InitContext behaviour * inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). * Document * Add a test to check that `init!!` doesn't change linking * Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. * Add some line breaks Co-authored-by: Markus Hauru * Add the option of no fallback for ParamsInit * Improve docstrings * typo * `p.default` -> `p.fallback` * Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}` --------- Co-authored-by: Markus Hauru --- docs/src/api.md | 21 ++++ src/DynamicPPL.jl | 9 +- src/contexts.jl | 35 ------ src/contexts/init.jl | 196 +++++++++++++++++++++++++++++++++ src/model.jl | 70 ++++++++++++ src/varnamedvector.jl | 5 + test/contexts.jl | 251 +++++++++++++++++++++++++++++++++++++++++- 7 files changed, 547 insertions(+), 40 deletions(-) create mode 100644 src/contexts/init.jl diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..c6244b75f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -463,6 +463,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +InitFromPrior +InitFromUniform +InitFromParams +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..859c7d49d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -108,6 +108,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + InitFromPrior, + InitFromUniform, + InitFromParams, # Pseudo distributions NamedDist, NoDist, @@ -169,11 +175,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..636847117 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,196 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). + +Any subtype of `AbstractInitStrategy` must implement the +[`DynamicPPL.init`](@ref) method. +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Return values must be unlinked" + The values returned by `init` must always be in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::InitFromUniform)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + InitFromPrior() + +Obtain new values by sampling from the prior distribution. +""" +struct InitFromPrior <: AbstractInitStrategy end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) + return rand(rng, dist) +end + +""" + InitFromUniform() + InitFromUniform(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. + +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. + +Requires that `lower <= upper`. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + InitFromUniform() = InitFromUniform(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + InitFromParams( + params::Union{AbstractDict{<:VarName},NamedTuple}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + +Obtain new values by extracting them from the given dictionary or NamedTuple. + +The parameter `fallback` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. `fallback` +can either be an initialisation strategy itself, in which case it will be +used to obtain new values, or it can be `nothing`, in which case an error +will be thrown. The default for `fallback` is `InitFromPrior()`. + +!!! note + The values in `params` must be provided in the space of the untransformed + distribution. +""" +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy + params::P + fallback::S + function InitFromParams( + params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} + ) + return new{typeof(params),typeof(fallback)}(params, fallback) + end + function InitFromParams(params::AbstractDict{<:VarName}) + return InitFromParams(params, InitFromPrior()) + end + function InitFromParams( + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + return InitFromParams(to_varname_dict(params), fallback) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + p.fallback === nothing && + error("A `missing` value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + else + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? + x + end + else + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=InitFromPrior()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + link_transform(dist) + else + identity + end + y, logjac = with_logabsdet_jacobian(f, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..e7a1a864f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) @@ -854,6 +889,41 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=InitFromPrior()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to InitFromPrior(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. +""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..365865e7e 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -431,4 +432,246 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + empty_varinfos = [ + ("untyped+metadata", VarInfo()), + ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), + ( + "typed+VNV", + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + ), + ("SVI+NamedTuple", SimpleVarInfo()), + ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end + + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end + + @testset "InitFromPrior" begin + test_generating_new_values(InitFromPrior()) + test_replacing_values(InitFromPrior()) + test_rng_respected(InitFromPrior()) + test_link_status_respected(InitFromPrior()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end + + @testset "InitFromUniform" begin + test_generating_new_values(InitFromUniform()) + test_replacing_values(InitFromUniform()) + test_rng_respected(InitFromUniform()) + test_link_status_respected(InitFromUniform()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end + + @testset "InitFromParams" begin + test_link_status_respected(InitFromParams((; a=1.0))) + test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0))) + + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. + params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict( + @varname(x) => my_x, @varname(y) => missing + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) + end + end + end + end + end end From c8e58412e060a1294958efd5ff1d8fa49ca98679 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 31 Aug 2025 14:29:44 +0100 Subject: [PATCH 05/14] use `varname_leaves` from AbstractPPL instead (#1030) * use `varname_leaves` from AbstractPPL instead * add changelog entry * fix import --- HISTORY.md | 3 + Project.toml | 2 +- docs/src/api.md | 2 - ext/DynamicPPLMCMCChainsExt.jl | 11 +- src/utils.jl | 239 --------------------------------- test/model.jl | 2 +- test/test_util.jl | 2 +- test/varinfo.jl | 2 +- 8 files changed, 10 insertions(+), 253 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 87bd2d552..91218d1fc 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,9 @@ ## 0.38.0 +The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical. + [...] ## 0.37.1 diff --git a/Project.toml b/Project.toml index 8e0ada64e..6dff71a03 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.13" +AbstractPPL = "0.13.1" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" diff --git a/docs/src/api.md b/docs/src/api.md index c6244b75f..d2150f3d7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -435,8 +435,6 @@ DynamicPPL.maybe_invlink_before_eval!! Base.merge(::AbstractVarInfo) DynamicPPL.subset DynamicPPL.unflatten -DynamicPPL.varname_leaves -DynamicPPL.varname_and_value_leaves ``` ### Evaluation Contexts diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index a29696720..48efc1464 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,12 +1,7 @@ module DynamicPPLMCMCChainsExt -if isdefined(Base, :get_extension) - using DynamicPPL: DynamicPPL - using MCMCChains: MCMCChains -else - using ..DynamicPPL: DynamicPPL - using ..MCMCChains: MCMCChains -end +using DynamicPPL: DynamicPPL, AbstractPPL +using MCMCChains: MCMCChains # Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata function DynamicPPL.loadstate(chain::MCMCChains.Chains) @@ -121,7 +116,7 @@ function DynamicPPL.predict( varname_vals = mapreduce( collect, vcat, - map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), + map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)), ) return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) diff --git a/src/utils.jl b/src/utils.jl index d3371271f..c7d1e089f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -837,245 +837,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -""" - varname_leaves(vn::VarName, val) - -Return an iterator over all varnames that are represented by `vn` on `val`. - -# Examples -```jldoctest -julia> using DynamicPPL: varname_leaves - -julia> foreach(println, varname_leaves(@varname(x), rand(2))) -x[1] -x[2] - -julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) -x[1:2][1] -x[1:2][2] - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_leaves(@varname(x), x)) -x.y -x.z[1][1] -x.z[2][1] -``` -""" -varname_leaves(vn::VarName, ::Real) = [vn] -function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for - I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_leaves( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I] - ) for I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val)) - end - return Iterators.flatten(iter) -end - -""" - varname_and_value_leaves(vn::VarName, val) - -Return an iterator over all varname-value pairs that are represented by `vn` on `val`. - -# Examples -```jldoctest varname-and-value-leaves -julia> using DynamicPPL: varname_and_value_leaves - -julia> foreach(println, varname_and_value_leaves(@varname(x), 1:2)) -(x[1], 1) -(x[2], 2) - -julia> foreach(println, varname_and_value_leaves(@varname(x[1:2]), 1:2)) -(x[1:2][1], 1) -(x[1:2][2], 2) - -julia> x = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(@varname(x), x)) -(x.y, 1) -(x.z[1][1], 2.0) -(x.z[2][1], 3.0) -``` - -There is also some special handling for certain types: - -```jldoctest varname-and-value-leaves -julia> using LinearAlgebra - -julia> x = reshape(1:4, 2, 2); - -julia> # `LowerTriangular` - foreach(println, varname_and_value_leaves(@varname(x), LowerTriangular(x))) -(x[1, 1], 1) -(x[2, 1], 2) -(x[2, 2], 4) - -julia> # `UpperTriangular` - foreach(println, varname_and_value_leaves(@varname(x), UpperTriangular(x))) -(x[1, 1], 1) -(x[1, 2], 3) -(x[2, 2], 4) - -julia> # `Cholesky` with lower-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'L', 0))) -(x.L[1, 1], 1.0) -(x.L[2, 1], 0.0) -(x.L[2, 2], 1.0) - -julia> # `Cholesky` with upper-triangular - foreach(println, varname_and_value_leaves(@varname(x), Cholesky([1.0 0.0; 0.0 1.0], 'U', 0))) -(x.U[1, 1], 1.0) -(x.U[1, 2], 0.0) -(x.U[2, 2], 1.0) -``` -""" -function varname_and_value_leaves(vn::VarName, x) - return Iterators.map(value, Iterators.flatten(varname_and_value_leaves_inner(vn, x))) -end - -""" - varname_and_value_leaves(container) - -Return an iterator over all varname-value pairs that are represented by `container`. - -This is the same as [`varname_and_value_leaves(vn::VarName, x)`](@ref) but over a container -containing multiple varnames. - -See also: [`varname_and_value_leaves(vn::VarName, x)`](@ref). - -# Examples -```jldoctest varname-and-value-leaves-container -julia> using DynamicPPL: varname_and_value_leaves - -julia> # With an `OrderedDict` - dict = OrderedDict(@varname(y) => 1, @varname(z) => [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(dict)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) - -julia> # With a `NamedTuple` - nt = (y = 1, z = [[2.0], [3.0]]); - -julia> foreach(println, varname_and_value_leaves(nt)) -(y, 1) -(z[1][1], 2.0) -(z[2][1], 3.0) -``` -""" -function varname_and_value_leaves(container::OrderedDict) - return Iterators.flatten(varname_and_value_leaves(k, v) for (k, v) in container) -end -function varname_and_value_leaves(container::NamedTuple) - return Iterators.flatten( - varname_and_value_leaves(VarName{k}(), v) for (k, v) in pairs(container) - ) -end - -""" - Leaf{T} - -A container that represents the leaf of a nested structure, implementing -`iterate` to return itself. - -This is particularly useful in conjunction with `Iterators.flatten` to -prevent flattening of nested structures. -""" -struct Leaf{T} - value::T -end - -Leaf(xs...) = Leaf(xs) - -# Allow us to treat `Leaf` as an iterator containing a single element. -# Something like an `[x]` would also be an iterator with a single element, -# but when we call `flatten` on this, it would also iterate over `x`, -# unflattening that too. By making `Leaf` a single-element iterator, which -# returns itself, we can call `iterate` on this as many times as we like -# without causing any change. The result is that `Iterators.flatten` -# will _not_ unflatten `Leaf`s. -# Note that this is similar to how `Base.iterate` is implemented for `Real`:: -# -# julia> iterate(1) -# (1, nothing) -# -# One immediate example where this becomes in our scenario is that we might -# have `missing` values in our data, which does _not_ have an `iterate` -# implemented. Calling `Iterators.flatten` on this would cause an error. -Base.iterate(leaf::Leaf) = leaf, nothing -Base.iterate(::Leaf, _) = nothing - -# Convenience. -value(leaf::Leaf) = leaf.value - -# Leaf-types. -varname_and_value_leaves_inner(vn::VarName, x::Real) = [Leaf(vn, x)] -function varname_and_value_leaves_inner( - vn::VarName, val::AbstractArray{<:Union{Real,Missing}} -) - return ( - Leaf( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -# Containers. -function varname_and_value_leaves_inner(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_and_value_leaves_inner( - VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ AbstractPPL.getoptic(vn)), - val[I], - ) for I in CartesianIndices(val) - ) -end -function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) - iter = Iterators.map(keys(val)) do k - optic = Accessors.PropertyLens{k}() - varname_and_value_leaves_inner( - VarName{getsym(vn)}(optic ∘ getoptic(vn)), optic(val) - ) - end - - return Iterators.flatten(iter) -end -# Special types. -function varname_and_value_leaves_inner(vn::VarName, x::Cholesky) - # TODO: Or do we use `PDMat` here? - return if x.uplo == 'L' - varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) - else - varname_and_value_leaves_inner(Accessors.PropertyLens{:U}() ∘ vn, x.U) - end -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.LowerTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the lower-triangular indices. - for I in CartesianIndices(x) if I[1] >= I[2] - ) -end -function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.UpperTriangular) - return ( - Leaf(VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), x[I]) - # Iteration over the upper-triangular indices. - for I in CartesianIndices(x) if I[1] <= I[2] - ) -end - broadcast_safe(x) = x broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) diff --git a/test/model.jl b/test/model.jl index 81f84e548..f062a70b4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -347,7 +347,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Extract varnames and values. vns_and_vals_xs = map( - collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs + collect ∘ Base.Fix1(AbstractPPL.varname_and_value_leaves, @varname(x)), xs ) vns = map(first, first(vns_and_vals_xs)) vals = map(vns_and_vals_xs) do vns_and_vals diff --git a/test/test_util.jl b/test/test_util.jl index e04486760..c6762ed45 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -72,7 +72,7 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I # We have to use varname_and_value_leaves so that each parameter is a scalar dicts = map(varinfos) do t vals = DynamicPPL.values_as(t, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) tuples = mapreduce(collect, vcat, iters) # The following loop is a replacement for: # push!(varnames, map(first, tuples)...) diff --git a/test/varinfo.jl b/test/varinfo.jl index ba7c17b34..f36af44a1 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -488,7 +488,7 @@ end θ_new = var_info[:] @test θ_old != θ_new vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)) for (n, v) in mapreduce(collect, vcat, iters) n = string(n) if Symbol(n) ∉ keys(chain) From fead2a287ca29888f8a01f1b3f416c5ce25bd482 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 31 Aug 2025 15:11:04 +0100 Subject: [PATCH 06/14] tidy occurrences of varname_leaves as well (#1031) --- src/model_utils.jl | 4 ++-- src/test_utils.jl | 2 +- src/test_utils/sampler.jl | 2 +- test/contexts.jl | 2 +- test/model.jl | 2 +- test/model_utils.jl | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index ac4ec7022..e4c326b39 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -81,7 +81,7 @@ function varname_in_chain!( # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the optic. # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) varname_in_chain!(x, l ∘ vn_parent, chain, chain_idx, iteration_idx, out) @@ -107,7 +107,7 @@ function values_from_chain( # This way we can use `getoptic(vn)` to extract the value from `x` and use `getoptic(vn) ∘ vn_parent` # to extract the value from the `chain`. out = similar(x) - for vn in varname_leaves(VarName{sym}(), x) + for vn in AbstractPPL.varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getoptic(vn) out = Accessors.set( diff --git a/src/test_utils.jl b/src/test_utils.jl index 65079f023..195345d60 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,7 +11,7 @@ using Bijectors: Bijectors using Accessors: Accessors # For backwards compat. -using DynamicPPL: varname_leaves, update_values!! +using DynamicPPL: update_values!! include("test_utils/model_interface.jl") include("test_utils/models.jl") diff --git a/src/test_utils/sampler.jl b/src/test_utils/sampler.jl index 71cdb1cac..3ef965bad 100644 --- a/src/test_utils/sampler.jl +++ b/src/test_utils/sampler.jl @@ -51,7 +51,7 @@ function test_sampler( for vn in filter(varnames_filter, varnames(model)) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varname_leaves(vn, get(target_values, vn)) + for vn_leaf in AbstractPPL.varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol diff --git a/test/contexts.jl b/test/contexts.jl index 365865e7e..107607d99 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -93,7 +93,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # here to split up arrays which could potentially have some, # but not all, elements being `missing`. conditioned_vns = mapreduce( - p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + p -> AbstractPPL.varname_leaves(p.first, p.second), vcat, pairs(conditioned_values), ) diff --git a/test/model.jl b/test/model.jl index f062a70b4..2234dde8f 100644 --- a/test/model.jl +++ b/test/model.jl @@ -71,7 +71,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() chain_sym_map = Dict{Symbol,Symbol}() for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) + vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym end diff --git a/test/model_utils.jl b/test/model_utils.jl index 720ae55aa..af695dbf2 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -6,11 +6,11 @@ chain = make_chain_from_prior(model, 10) for (i, d) in enumerate(value_iterator_from_chain(model, chain)) for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) + val = AbstractPPL.getvalue(d, vn) # Because value_iterator_from_chain groups varnames with # the same parent symbol, we have to ungroup them here - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) + for vn_leaf in AbstractPPL.varname_leaves(vn, val) + val_leaf = AbstractPPL.getvalue(d, vn_leaf) @test val_leaf == chain[i, Symbol(vn_leaf), 1] end end From 729bfba4e6ed736a2df0a603bbe8acddacd8a800 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 18 Sep 2025 10:29:20 +0100 Subject: [PATCH 07/14] `InitContext`, part 4 - Use `init!!` to replace `evaluate_and_sample!!`, `predict`, `returned`, and `initialize_values` (#984) * Replace `evaluate_and_sample!!` -> `init!!` * Use `ParamsInit` for `predict`; remove `setval_and_resample!` and friends * Use `init!!` for initialisation * Paper over the `Sampling->Init` context stack (pending removal of SamplingContext) * Remove SamplingContext from JETExt to avoid triggering `Sampling->Init` pathway * Remove `predict` on vector of VarInfo * Fix some tests * Remove duplicated test * Simplify context testing * Rename FooInit -> InitFromFoo * Fix JETExt * Fix JETExt properly * Fix tests * Improve comments * Remove duplicated tests * Docstring improvements Co-authored-by: Markus Hauru * Concretise `chain_sample_to_varname_dict` using chain value type * Clarify testset name * Re-add comment that shouldn't have vanished * Fix stale Requires dep * Fix default_varinfo/initialisation for odd models * Add comment to src/sampler.jl Co-authored-by: Markus Hauru --------- Co-authored-by: Markus Hauru --- Project.toml | 2 - docs/src/api.md | 14 +- ext/DynamicPPLJETExt.jl | 43 +++--- ext/DynamicPPLMCMCChainsExt.jl | 42 ++++-- src/DynamicPPL.jl | 4 - src/extract_priors.jl | 2 +- src/model.jl | 65 ++------ src/sampler.jl | 161 +++++--------------- src/simple_varinfo.jl | 52 ++++--- src/test_utils/contexts.jl | 80 ++++++---- src/test_utils/model_interface.jl | 4 +- src/varinfo.jl | 223 ++++++---------------------- test/ad.jl | 3 +- test/compiler.jl | 16 +- test/contexts.jl | 20 +-- test/ext/DynamicPPLJETExt.jl | 42 +++--- test/ext/DynamicPPLMCMCChainsExt.jl | 7 +- test/model.jl | 67 ++------- test/sampler.jl | 75 +++------- test/simple_varinfo.jl | 8 +- test/test_util.jl | 4 +- test/varinfo.jl | 130 +++------------- test/varnamedvector.jl | 4 +- 23 files changed, 348 insertions(+), 720 deletions(-) diff --git a/Project.toml b/Project.toml index 6dff71a03..f2e39b778 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -71,7 +70,6 @@ Mooncake = "0.4.147" OrderedCollections = "1" Printf = "1.10" Random = "1.6" -Requires = "1" Statistics = "1" Test = "1.6" julia = "1.10.8" diff --git a/docs/src/api.md b/docs/src/api.md index d2150f3d7..d1dddb560 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -447,11 +447,6 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. -To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: - -```@docs -DynamicPPL.evaluate_and_sample!! -``` The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. @@ -466,7 +461,12 @@ InitContext ### VarInfo initialisation -`InitContext` is used to initialise, or overwrite, values in a VarInfo. +The function `init!!` is used to initialise, or overwrite, values in a VarInfo. +It is really a thin wrapper around using `evaluate!!` with an `InitContext`. + +```@docs +DynamicPPL.init!! +``` To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. There are three concrete strategies provided in DynamicPPL: @@ -505,7 +505,7 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep DynamicPPL.loadstate -DynamicPPL.initialsampler +DynamicPPL.init_strategy ``` Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 760d17bb0..55016d40c 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -6,7 +6,6 @@ using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) - # Let's make sure that both evaluation and sampling doesn't result in type errors. f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. @@ -21,32 +20,40 @@ end function DynamicPPL.Experimental._determine_varinfo_jet( model::DynamicPPL.Model; only_ddpl::Bool=true ) - # Use SamplingContext to test type stability. - sampling_model = DynamicPPL.contextualize( - model, DynamicPPL.SamplingContext(model.context) - ) - - # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(sampling_model) + # Generate a typed varinfo to test model type stability with + varinfo = DynamicPPL.typed_varinfo(model) - # Let's make sure that both evaluation and sampling doesn't result in type errors. - issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - sampling_model, varinfo; only_ddpl + # Check type stability of evaluation (i.e. DefaultContext) + model = DynamicPPL.contextualize( + model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()) + ) + eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl ) + if !eval_issuccess + @debug "Evaluation with typed varinfo failed with the following issues:" + @debug eval_result + end - if !issuccess - # Useful information for debugging. - @debug "Evaluaton with typed varinfo failed with the following issues:" - @debug result + # Check type stability of initialisation (i.e. InitContext) + model = DynamicPPL.contextualize( + model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) + ) + init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( + model, varinfo; only_ddpl + ) + if !init_issuccess + @debug "Initialisation with typed varinfo failed with the following issues:" + @debug init_result end - # If we didn't fail anywhere, we return the type stable one. - return if issuccess + # If neither of them failed, we can return the typed varinfo as it's type stable. + return if (eval_issuccess && init_issuccess) varinfo else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(sampling_model) + DynamicPPL.untyped_varinfo(model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 48efc1464..7b9322254 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -23,7 +23,7 @@ end function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using `VarName`s.") + error("This `Chains` object does not support indexing using `VarName`s.") end function DynamicPPL.getindex_varname( @@ -37,6 +37,17 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +function chain_sample_to_varname_dict( + c::MCMCChains.Chains{Tval}, sample_idx, chain_idx +) where {Tval} + _check_varname_indexing(c) + d = Dict{DynamicPPL.VarName,Tval}() + for vn in DynamicPPL.varnames(c) + d[vn] = DynamicPPL.getindex_varname(c, sample_idx, vn, chain_idx) + end + return d +end + """ predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false) @@ -109,9 +120,15 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) - + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict` + _, varinfo = DynamicPPL.init!!( + rng, + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( collect, @@ -243,13 +260,16 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to the `model`. - model(deepcopy(varinfo)) + # Extract values from the chain + values_dict = chain_sample_to_varname_dict(chain, sample_idx, chain_idx) + # Resample any variables that are not present in `values_dict`, and + # return the model's retval. + retval, _ = DynamicPPL.init!!( + model, + varinfo, + DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()), + ) + retval end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 859c7d49d..6a01884a9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -206,10 +206,6 @@ include("test_utils.jl") include("experimental.jl") include("deprecated.jl") -if !isdefined(Base, :get_extension) - using Requires -end - # Better error message if users forget to load JET if isdefined(Base.Experimental, :register_error_hint) function __init__() diff --git a/src/extract_priors.jl b/src/extract_priors.jl index d311a5f63..8c7b5f7db 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -123,7 +123,7 @@ extract_priors(args::Union{Model,AbstractVarInfo}...) = function extract_priors(rng::Random.AbstractRNG, model::Model) varinfo = VarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) - varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/model.jl b/src/model.jl index e7a1a864f..a6a3e0685 100644 --- a/src/model.jl +++ b/src/model.jl @@ -850,7 +850,7 @@ end # ^ Weird Documenter.jl bug means that we have to write the two above separately # as it can only detect the `function`-less syntax. function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) - return first(evaluate_and_sample!!(rng, model, varinfo)) + return first(init!!(rng, model, varinfo)) end """ @@ -863,32 +863,6 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) return Threads.nthreads() > 1 end -""" - evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) - -Evaluate the `model` with the given `varinfo`, but perform sampling during the -evaluation using the given `sampler` by wrapping the model's context in a -`SamplingContext`. - -If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). - -Returns a tuple of the model's return value, plus the updated `varinfo` object. -""" -function evaluate_and_sample!!( - rng::Random.AbstractRNG, - model::Model, - varinfo::AbstractVarInfo, - sampler::AbstractSampler=SampleFromPrior(), -) - sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return evaluate!!(sampling_model, varinfo) -end -function evaluate_and_sample!!( - model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() -) - return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) -end - """ init!!( [rng::Random.AbstractRNG,] @@ -897,12 +871,12 @@ end [init_strategy::AbstractInitStrategy=InitFromPrior()] ) -Evaluate the `model` and replace the values of the model's random variables in -the given `varinfo` with new values using a specified initialisation strategy. -If the values in `varinfo` are not already present, they will be added using -that same strategy. +Evaluate the `model` and replace the values of the model's random variables +in the given `varinfo` with new values, using a specified initialisation strategy. +If the values in `varinfo` are not set, they will be added +using a specified initialisation strategy. -If `init_strategy` is not provided, defaults to InitFromPrior(). +If `init_strategy` is not provided, defaults to `InitFromPrior()`. Returns a tuple of the model's return value, plus the updated `varinfo` object. """ @@ -1051,11 +1025,7 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last( - evaluate_and_sample!!( - rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) - ), - ) + x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) return values_as(x, T) end @@ -1227,25 +1197,8 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) - -Generate samples from the posterior predictive distribution by evaluating `model` at each set -of parameter values provided in `chain`. The number of posterior predictive samples matches -the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values -and the predicted values. -""" -function predict( - rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} -) - varinfo = DynamicPPL.VarInfo(model) - return map(chain) do params_varinfo - vi = deepcopy(varinfo) - DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi) - return vi - end -end +# Implemented & documented in DynamicPPLMCMCChainsExt +function predict end """ returned(model::Model, parameters::NamedTuple) diff --git a/src/sampler.jl b/src/sampler.jl index 27b990336..98b50ba55 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -41,7 +41,7 @@ Generic sampler type for inference algorithms of type `T` in DynamicPPL. provided that supports resuming sampling from a previous state and setting initial parameter values. It requires to overload [`loadstate`](@ref) and [`initialstep`](@ref) for loading previous states and actually performing the initial sampling step, -respectively. Additionally, sometimes one might want to implement [`initialsampler`](@ref) +respectively. Additionally, sometimes one might want to implement an [`init_strategy`](@ref) that specifies how the initial parameter values are sampled if they are not provided. By default, values are sampled from the prior. """ @@ -58,8 +58,9 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) - return vi, nothing + strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform() + _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) + return new_vi, nothing end """ @@ -67,6 +68,8 @@ end Return a default varinfo object for the given `model` and `sampler`. +The default method for this returns a NTVarInfo (i.e. 'typed varinfo'). + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. @@ -75,11 +78,26 @@ Return a default varinfo object for the given `model` and `sampler`. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ -function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler) +function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSampler) + # Note that in `AbstractMCMC.step`, the values in the varinfo returned here are + # immediately overwritten by a subsequent call to `init!!`. The reason why we + # _do_ create a varinfo with parameters here (as opposed to simply returning + # an empty `typed_varinfo(VarInfo())`) is to avoid issues where pushing to an empty + # typed VarInfo would fail. This can happen if two VarNames have different types + # but share the same symbol (e.g. `x.a` and `x.b`). + # TODO(mhauru) Fix push!! to work with arbitrary lens types, and then remove the arguments + # and return an empty VarInfo instead. + return typed_varinfo(VarInfo(rng, model)) end +""" + init_strategy(sampler) + +Define the initialisation strategy used for generating initial values when +sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. +""" +init_strategy(::Sampler) = InitFromPrior() + function AbstractMCMC.sample( rng::Random.AbstractRNG, model::Model, @@ -112,24 +130,24 @@ function AbstractMCMC.sample( ) end -# initial step: general interface for resuming and function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params::AbstractInitStrategy=init_strategy(spl), + kwargs..., ) - # Sample initial values. + # Generate the default varinfo. Note that any parameters inside this varinfo + # will be immediately overwritten by the next call to `init!!`. vi = default_varinfo(rng, model, spl) - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi)) - end + # Fill it with initial parameters. Note that, if `InitFromParams` is used, the + # parameters provided must be in unlinked space (when inserted into the + # varinfo, they will be adjusted to match the linking status of the + # varinfo). + _, vi = init!!(rng, model, vi, initial_params) + # Call the actual function that does the first step. return initialstep(rng, model, spl, vi; initial_params, kwargs...) end @@ -147,110 +165,7 @@ loadstate(data) = data Default type of the chain of posterior samples from `sampler`. """ -default_chain_type(sampler::Sampler) = Any - -""" - initialsampler(sampler::Sampler) - -Return the sampler that is used for generating the initial parameters when sampling with -`sampler`. - -By default, it returns an instance of [`SampleFromPrior`](@ref). -""" -initialsampler(spl::Sampler) = SampleFromPrior() - -""" - set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - -Take the values inside `initial_params`, replace the corresponding values in -the given VarInfo object, and return a new VarInfo object with the updated values. - -This differs from `DynamicPPL.unflatten` in two ways: - -1. It works with `NamedTuple` arguments. -2. For the `AbstractVector` method, if any of the elements are missing, it will not -overwrite the original value in the VarInfo (it will just use the original -value instead). -""" -function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) - throw( - ArgumentError( - "`initial_params` must be a vector of type `Union{Real,Missing}`. " * - "If `initial_params` is a vector of vectors, please flatten it (e.g. using `vcat`) first.", - ), - ) -end - -function set_initial_values( - varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} -) - flattened_param_vals = varinfo[:] - length(flattened_param_vals) == length(initial_params) || throw( - DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match " * - "the model size ($(length(flattened_param_vals))).", - ), - ) - - # Update values that are provided. - for i in eachindex(initial_params) - x = initial_params[i] - if x !== missing - flattened_param_vals[i] = x - end - end - - # Update in `varinfo`. - new_varinfo = unflatten(varinfo, flattened_param_vals) - return new_varinfo -end - -function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) - varinfo = deepcopy(varinfo) - vars_in_varinfo = keys(varinfo) - for v in keys(initial_params) - vn = VarName{v}() - if !(vn in vars_in_varinfo) - for vv in vars_in_varinfo - if subsumes(vn, vv) - throw( - ArgumentError( - "The current model contains sub-variables of $v, such as ($vv). " * - "Using NamedTuple for initial_params is not supported in such a case. " * - "Please use AbstractVector for initial_params instead of NamedTuple.", - ), - ) - end - end - throw(ArgumentError("Variable $v not found in the model.")) - end - end - initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing) - return update_values!!( - varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params)) - ) -end - -function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Model) - @debug "Using passed-in initial variable values" initial_params - - # `link` the varinfo if needed. - linked = islinked(vi) - if linked - vi = invlink!!(vi, model) - end - - # Set the values in `vi`. - vi = set_initial_values(vi, initial_params) - - # `invlink` if needed. - if linked - vi = link!!(vi, model) - end - - return vi -end +default_chain_type(::Sampler) = Any """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index cfad93ed9..27365e4dc 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -39,7 +39,7 @@ julia> rng = StableRNG(42); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -57,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); + _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -91,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); +julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -232,24 +232,27 @@ end # Constructor from `Model`. function SimpleVarInfo{T}( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) where {T<:Real} - new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) - return last(evaluate!!(new_model, SimpleVarInfo{T}())) + return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) end function SimpleVarInfo{T}( - model::Model, sampler::AbstractSampler=SampleFromPrior() + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() ) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, sampler) + return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) end # Constructors without type param function SimpleVarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return SimpleVarInfo{LogProbType}(rng, model, sampler) + return SimpleVarInfo{LogProbType}(rng, model, init_strategy) end -function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) +function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) end # Constructor from `VarInfo`. @@ -265,12 +268,12 @@ end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate_and_sample!!(model, varinfo)) + return last(init!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) @@ -482,7 +485,6 @@ function assume( return value, vi end -# NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end @@ -492,6 +494,16 @@ end function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) end +function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) + # We keep this method around just to obey the AbstractVarInfo interface. + # However, note that this would only be a valid operation if it would be a + # no-op, which we check here. + if trans != istrans(vi) + error( + "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + ) + end +end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 863db4262..d53ba6c5f 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -25,25 +25,49 @@ This method ensures that `context` - Correctly implements the tilde-pipeline. """ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) - # `NodeTrait`. node_trait = DynamicPPL.NodeTrait(context) - # Throw error immediately if it it's missing a `NodeTrait` implementation. - node_trait isa Union{DynamicPPL.IsLeaf,DynamicPPL.IsParent} || - throw(ValueError("Invalid NodeTrait: $node_trait")) - - # To see change, let's make sure we're using a different leaf context than the current. - leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + if node_trait isa DynamicPPL.IsLeaf + test_leaf_context(context, model) + elseif node_trait isa DynamicPPL.IsParent + test_parent_context(context, model) else - DefaultContext() + error("Invalid NodeTrait: $node_trait") end - @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == - leafcontext_new +end + +function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf - # The interface methods. - if node_trait isa DynamicPPL.IsParent - # `childcontext` and `setchildcontext` - # With new child context + # Note that for a leaf context we can't assume that it will work with an + # empty VarInfo. (For example, DefaultContext will error with empty + # varinfos.) Thus we only test evaluation with VarInfos that are already + # filled with values. + @testset "evaluation" begin + # Generate a new filled untyped varinfo + _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) + typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + # Set the test context as the new leaf context + new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context)) + # Check that evaluation works + for vi in [untyped_vi, typed_vi] + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end +end + +function test_parent_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Model) + @test DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent + + @testset "get/set leaf and child contexts" begin + # Ensure we're using a different leaf context than the current. + leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext + DynamicPPL.DynamicTransformationContext{false}() + else + DefaultContext() + end + @test DynamicPPL.leafcontext(DynamicPPL.setleafcontext(context, leafcontext_new)) == + leafcontext_new childcontext_new = TestParentContext() @test DynamicPPL.childcontext( DynamicPPL.setchildcontext(context, childcontext_new) @@ -56,19 +80,15 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod leafcontext_new end - # Make sure that the we can evaluate the model with the context (i.e. that none of the tilde-functions are incorrectly overloaded). - # The tilde-pipeline contains two different paths: with `SamplingContext` as a parent, and without it. - # NOTE(torfjelde): Need to sample with the untyped varinfo _using_ the context, since the - # context might alter which variables are present, their names, etc., e.g. `PrefixContext`. - # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. - # Untyped varinfo. - varinfo_untyped = DynamicPPL.VarInfo() - model_with_spl = contextualize(model, SamplingContext(context)) - model_without_spl = contextualize(model, context) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any - # Typed varinfo. - varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any - @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any + @testset "initialisation and evaluation" begin + new_model = contextualize(model, context) + for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo + end + end end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index 93aed074c..cb949464e 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -92,9 +92,7 @@ Even though it is recommended to implement this by hand for a particular `Model` a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) - return collect( - keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) - ) + return collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(Dict()))))) end """ diff --git a/src/varinfo.jl b/src/varinfo.jl index dec4db3ec..081f65ea1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -113,10 +113,14 @@ function VarInfo(meta=Metadata()) end """ - VarInfo([rng, ]model[, sampler]) + VarInfo( + [rng::Random.AbstractRNG], + model, + [init_strategy::AbstractInitStrategy] + ) -Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`. +Generate a `VarInfo` object for the given `model`, by initialising it with the +given `rng` and `init_strategy`. !!! warning @@ -129,12 +133,14 @@ the given `rng`, `sampler`. instead. """ function VarInfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(rng, model, sampler) + return typed_varinfo(rng, model, init_strategy) end -function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return VarInfo(Random.default_rng(), model, sampler) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -195,7 +201,7 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler]) + untyped_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -203,15 +209,17 @@ Construct a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) + return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_varinfo(Random.default_rng(), model, sampler) +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -270,7 +278,7 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler]) + typed_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. @@ -278,19 +286,21 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function typed_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(untyped_varinfo(rng, model, sampler)) + return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_varinfo(Random.default_rng(), model, sampler) +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return typed_varinfo(Random.default_rng(), model, init_strategy) end """ - untyped_vector_varinfo([rng, ]model[, sampler]) + untyped_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has just a single `VarNamedVector` as its metadata field. @@ -298,23 +308,27 @@ Return a VarInfo object for the given `model`, which has just a single # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) + return untyped_vector_varinfo(untyped_varinfo(rng, model, init_strategy)) end -function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return untyped_vector_varinfo(Random.default_rng(), model, sampler) +function untyped_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_vector_varinfo([rng, ]model[, sampler]) + typed_vector_varinfo([rng, ]model[, init_strategy]) Return a VarInfo object for the given `model`, which has a NamedTuple of `VarNamedVector`s as its metadata field. @@ -322,7 +336,7 @@ Return a VarInfo object for the given `model`, which has a NamedTuple of # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object -- `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. +- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) @@ -334,12 +348,16 @@ function typed_vector_varinfo(vi::UntypedVectorVarInfo) return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy)) end -function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) - return typed_vector_varinfo(Random.default_rng(), model, sampler) +function typed_vector_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return typed_vector_varinfo(Random.default_rng(), model, init_strategy) end """ @@ -1508,42 +1526,6 @@ function islinked(vi::VarInfo) return any(istrans(vi, vn) for vn in keys(vi)) end -function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) - return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) -end -function nested_setindex_maybe!( - vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} -) where {names,sym} - return if sym in names - _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) - else - nothing - end -end -function _nested_setindex_maybe!( - vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName -) - # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = Base.keys(md) - if vn in vns - setindex!(vi, val, vn) - return vn - end - - # Otherwise, we need to check if either of the `vns` subsumes `vn`. - i = findfirst(Base.Fix2(subsumes, vn), vns) - i === nothing && return nothing - - vn_parent = vns[i] - val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. - # Split the varname into its tail optic. - optic = remove_parent_optic(vn_parent, vn) - # Update the value for the parent. - val_parent_updated = set!!(val_parent, optic, val) - setindex!(vi, val_parent_updated, vn_parent) - return vn_parent -end - # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type function getindex(vi::VarInfo, vn::VarName) @@ -1966,113 +1948,6 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke return indices end -""" - setval_and_resample!(vi::VarInfo, x) - setval_and_resample!(vi::VarInfo, values, keys) - setval_and_resample!(vi::VarInfo, chains::AbstractChains, sample_idx, chain_idx) - -Set the values in `vi` to the provided values and those which are not present -in `x` or `chains` to *be* resampled. - -Note that this does *not* resample the values not provided! It will call -`setflag!(vi, vn, "del")` for variables `vn` for which no values are provided, which means -that the next time we call `model(vi)` these variables will be resampled. - -## Note -- This suffers from the same limitations as [`setval!`](@ref). See `setval!` for more info. - -## Example -```jldoctest -julia> using DynamicPPL, Distributions, StableRNGs - -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1) - end - end; - -julia> rng = StableRNG(42); - -julia> m = demo([missing]); - -julia> var_info = DynamicPPL.VarInfo(rng, m); - # Checking the setting of "del" flags only makes sense for VarInfo{<:Metadata}. For VarInfo{<:VarNamedVector} the flag is effectively always set. - -julia> var_info[@varname(m)] --0.6702516921145671 - -julia> var_info[@varname(x[1])] --0.22312984965118443 - -julia> DynamicPPL.setval_and_resample!(var_info, (m = 100.0, )); # set `m` and ready `x[1]` for resampling - -julia> var_info[@varname(m)] # [✓] changed -100.0 - -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # sample `x[1]` conditioned on `m = 100.0` - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - -julia> var_info[@varname(x[1])] # [✓] changed -101.37363069798343 -``` - -## See also -- [`setval!`](@ref) -""" -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) - return setval_and_resample!(vi, values(x), keys(x)) -end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) - return _apply!(_setval_and_resample_kernel!, vi, values, keys) -end -function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int -) - if supports_varname_indexing(chains) - # First we need to set every variable to be resampled. - for vn in keys(vi) - set_flag!(vi, vn, "del") - end - # Then we set the variables in `varinfo` from `chain`. - for vn in varnames(chains) - vn_updated = nested_setindex_maybe!( - vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn - ) - - # Unset the `del` flag if we found something. - if vn_updated !== nothing - # NOTE: This will be triggered even if only a subset of a variable has been set! - unset_flag!(vi, vn_updated, "del") - end - end - else - setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) - end -end - -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) - indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) - if !isempty(indices) - val = reduce(vcat, values[indices]) - setval!(vi, val, vn) - settrans!!(vi, false, vn) - else - # Ensures that we'll resample the variable corresponding to `vn` if we run - # the model on `vi` again. - set_flag!(vi, vn, "del") - end - - return indices -end - values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) diff --git a/test/ad.jl b/test/ad.jl index 371e79b06..0e5d8d7cf 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -111,9 +111,10 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) + vi = DynamicPPL.link!!(VarInfo(model), model) sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal, vi; adtype=AutoReverseDiff(; compile=true) ) x = ldf.varinfo[:] @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..b1309254e 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -193,11 +193,11 @@ module Issue537 end varinfo = VarInfo(model) @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - # During the model evaluation, its context is wrapped in a - # SamplingContext, so `model_` is not going to be equal to `model`. - # We can still check equality of `f` though. + # During the model evaluation, its leaf context is changed to an InitContext, so + # `model_` is not going to be equal to `model`. We can still check equality of `f` + # though. @test model_.f === model.f - @test model_.context isa SamplingContext + @test model_.context isa DynamicPPL.InitContext @test model_.context.rng isa Random.AbstractRNG # disable warnings @@ -598,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) + retval_and_vi = DynamicPPL.init!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -620,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) + retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 107607d99..1a6279bf4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -50,7 +50,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :sampling => SamplingContext(), :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( @@ -166,29 +165,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext(@varname(a))) + ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext() + @test new_ctx == FixedContext((b=4,)) ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + ctx4 = FixedContext( + (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) + ) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) - @test new_ctx == SamplingContext(ConditionContext((a=1,))) + @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) - sampling_model = contextualize(model, context) - # Sample with the context. - varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(sampling_model, varinfo) + context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) + new_model = contextualize(model, context) + # Initialize a new varinfo with the prefixed model + _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 6737cf056..692f53911 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -30,7 +30,7 @@ DynamicPPL.UntypedVarInfo # Evaluation works (and it would even do so in practice), but sampling - # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. + # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. @model function demo4() x ~ Bernoulli() if x @@ -62,33 +62,37 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation and sampling - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, varinfo - ) - JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed - # If the test failed, check why it didn't infer a typed varinfo + # If the test failed, check what the type stability problem was for + # the typed varinfo. This is mostly useful for debugging from test + # logs. if !is_typed + @info "Model `$(model.f)` is not type stable with typed varinfo." typed_vi = DynamicPPL.typed_varinfo(model) - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi + + @info "Evaluating with DefaultContext:" + model = DynamicPPL.contextualize( + model, + DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()), + ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f, argtypes) + + @info "Initialising with InitContext:" + model = DynamicPPL.contextualize( + model, + DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()), ) - JET.test_call(f_eval, argtypes_eval) - f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - sampling_model, typed_vi + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo ) - JET.test_call(f_sample, argtypes_sample) + JET.test_call(f, argtypes) end end end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 3ba5edfe1..79e13ad84 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -2,7 +2,12 @@ @model demo() = x ~ Normal() model = demo() - chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) + chain = MCMCChains.Chains( + randn(1000, 2, 1), + [:x, :y], + Dict(:internals => [:y]); + info=(; varname_to_symbol=Dict(@varname(x) => :x)), + ) chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 diff --git a/test/model.jl b/test/model.jl index 2234dde8f..7374f73aa 100644 --- a/test/model.jl +++ b/test/model.jl @@ -155,24 +155,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() logjoint(model, chain) end - @testset "rng" begin - model = GDEMO_DEFAULT - - for sampler in (SampleFromPrior(), SampleFromUniform()) - for i in 1:10 - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - vals = vi[:] - - Random.seed!(100 + i) - vi = VarInfo() - DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) - @test vi[:] == vals - end - end - end - @testset "defaults without VarInfo, Sampler, and Context" begin model = GDEMO_DEFAULT @@ -332,7 +314,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) + vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -513,7 +495,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct a chain with 'sampled values' of β ground_truth_β = 2 - β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + β_chain = MCMCChains.Chains( + rand(Normal(ground_truth_β, 0.002), 1000), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), + ) # Generate predictions from that chain xs_test = [10 + 0.1, 10 + 2 * 0.1] @@ -559,7 +545,9 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "prediction from multiple chains" begin # Normal linreg model multiple_β_chain = MCMCChains.Chains( - reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), + [:β]; + info=(; varname_to_symbol=Dict(@varname(β) => :β)), ) predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) @test size(multiple_β_chain, 3) == size(predictions, 3) @@ -584,43 +572,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end - - @testset "with AbstractVector{<:AbstractVarInfo}" begin - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(1, 1) - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - ground_truth_β = 2.0 - # the data will be ignored, as we are generating samples from the prior - xs_train = 1:0.1:10 - ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) - m_lin_reg = linear_reg(xs_train, ys_train) - chain = [ - last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for - _ in 1:10000 - ] - - # chain is generated from the prior - @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 - - xs_test = [10 + 0.1, 10 + 2 * 0.1] - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) - predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) - - @test size(predicted_vis) == size(chain) - @test Set(keys(predicted_vis[1])) == - Set([@varname(β), @varname(y[1]), @varname(y[2])]) - # because β samples are from the prior, the std will be larger - @test mean([ - predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[1] rtol = 0.1 - @test mean([ - predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) - ]) ≈ 1.0 * xs_test[2] rtol = 0.1 - end end @testset "ProductNamedTupleDistribution sampling" begin diff --git a/test/sampler.jl b/test/sampler.jl index 5eb0da057..c812de938 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -1,4 +1,17 @@ @testset "sampler.jl" begin + @testset "varnames with same symbol but different type" begin + struct S <: AbstractMCMC.AbstractSampler end + DynamicPPL.initialstep(rng, model, ::DynamicPPL.Sampler{S}, vi; kwargs...) = vi + @model function g() + y = (; a=1, b=2) + y.a ~ Normal() + return y.b ~ Normal() + end + model = g() + spl = DynamicPPL.Sampler(S()) + @test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any + end + @testset "initial_state and resume_from kwargs" begin # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our # overloaded method. @@ -126,8 +139,8 @@ @test length(chains) == N # `m` is Gaussian, i.e. no transformation is used, so it - # should have a mean equal to its prior, i.e. 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + # will be drawn from U[-2, 2] and its mean should be 0. + @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 @@ -170,8 +183,8 @@ end # initial samplers - DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() - @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() + DynamicPPL.init_strategy(::Sampler{OnlyInitAlgUniform}) = InitFromUniform() + @test DynamicPPL.init_strategy(Sampler(OnlyInitAlgDefault())) == InitFromPrior() for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) # model with one variable: initialization p = 0.2 @@ -182,7 +195,7 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - let inits = (; p=0.2) + let inits = InitFromParams((; p=0.2)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogjoint(chain[1]) == lptrue @@ -210,7 +223,7 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - for inits in ([4, -1], (; s=4, m=-1)) + let inits = InitFromParams((; s=4, m=-1)) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @@ -234,7 +247,7 @@ end # set only m = -1 - for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1)) + for inits in (InitFromParams((; s=missing, m=-1)), InitFromParams((; m=-1))) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -254,54 +267,6 @@ @test c[1].metadata.m.vals == [-1] end end - - # specify `initial_params=nothing` - Random.seed!(1234) - chain1 = sample(model, sampler, 1; progress=false) - Random.seed!(1234) - chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) - @test_throws DimensionMismatch sample( - model, sampler, 1; progress=false, initial_params=zeros(10) - ) - @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals - @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals - - # parallel sampling - Random.seed!(1234) - chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) - Random.seed!(1234) - chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false - ) - for (c1, c2) in zip(chains1, chains2) - @test c1[1].metadata.m.vals == c2[1].metadata.m.vals - @test c1[1].metadata.s.vals == c2[1].metadata.s.vals - end - end - - @testset "error handling" begin - # https://github.com/TuringLang/Turing.jl/issues/2452 - @model function constrained_uniform(n) - Z ~ Uniform(10, 20) - X = Vector{Float64}(undef, n) - for i in 1:n - X[i] ~ Uniform(0, Z) - end - end - - n = 2 - initial_z = 15 - initial_x = [0.2, 0.5] - model = constrained_uniform(n) - vi = VarInfo(model) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, [initial_z, initial_x], model - ) - - @test_throws ArgumentError DynamicPPL.initialize_parameters!!( - vi, (X=initial_x, Z=initial_z), model - ) end end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 526fce92c..01cbfc593 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -160,7 +160,7 @@ ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) + _, svi_new = DynamicPPL.init!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -228,9 +228,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) + svi_nt = last(DynamicPPL.init!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) + svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -275,7 +275,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) + vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. diff --git a/test/test_util.jl b/test/test_util.jl index c6762ed45..164751c7b 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -87,8 +87,10 @@ function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::I varnames = collect(varnames) # Construct matrix of values vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct dict of varnames -> symbol + vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames))) # Construct and return the Chains object - return Chains(vals, varnames) + return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict)) end function make_chain_from_prior(model::Model, n_iters::Int) return make_chain_from_prior(Random.default_rng(), model, n_iters) diff --git a/test/varinfo.jl b/test/varinfo.jl index f36af44a1..75d8e062b 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -42,7 +42,7 @@ end end model = gdemo(1.0, 2.0) - vi = DynamicPPL.untyped_varinfo(model, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) tvi = DynamicPPL.typed_varinfo(vi) meta = vi.metadata @@ -325,7 +325,7 @@ end @test typed_vi[vn_y] == 2.0 end - @testset "setval! & setval_and_resample!" begin + @testset "setval!" begin @model function testmodel(x) n = length(x) s ~ truncated(Normal(); lower=0) @@ -376,8 +376,8 @@ end else DynamicPPL.setval!(vicopy, (m=zeros(5),)) end - # Setting `m` fails for univariate due to limitations of `setval!` - # and `setval_and_resample!`. See docstring of `setval!` for more info. + # Setting `m` fails for univariate due to limitations of `setval!`. + # See docstring of `setval!` for more info. if model == model_uv && vi in [vi_untyped, vi_typed] @test_broken vicopy[m_vns] == zeros(5) else @@ -402,57 +402,6 @@ end DynamicPPL.setval!(vicopy, (s=42,)) @test vicopy[m_vns] == 1:5 @test vicopy[s_vns] == 42 - - ### `setval_and_resample!` ### - if model == model_mv && vi == vi_untyped - # Trying to re-run model with `MvNormal` on `vi_untyped` will call - # `MvNormal(μ::Vector{Real}, Σ)` which causes `StackOverflowError` - # so we skip this particular case. - continue - end - - if vi in [vi_vnv, vi_vnv_typed] - # `setval_and_resample!` works differently for `VarNamedVector`: All - # values will be resampled when model(vicopy) is called. Hence the below - # tests are not applicable. - continue - end - - vicopy = deepcopy(vi) - DynamicPPL.setval_and_resample!(vicopy, (m=zeros(5),)) - model(vicopy) - # Setting `m` fails for univariate due to limitations of `subsumes(::String, ::String)` - if model == model_uv - @test_broken vicopy[m_vns] == zeros(5) - else - @test vicopy[m_vns] == zeros(5) - end - @test vicopy[s_vns] != vi[s_vns] - - # Ordering is NOT preserved. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...) - ) - model(vicopy) - if model == model_uv - @test vicopy[m_vns] == 1:5 - else - @test vicopy[m_vns] == [1, 3, 5, 4, 2] - end - @test vicopy[s_vns] != vi[s_vns] - - # Correct ordering. - DynamicPPL.setval_and_resample!( - vicopy, (; (Symbol("m[$i]") => i for i in (1, 2, 3, 4, 5))...) - ) - model(vicopy) - @test vicopy[m_vns] == 1:5 - @test vicopy[s_vns] != vi[s_vns] - - DynamicPPL.setval_and_resample!(vicopy, (s=42,)) - model(vicopy) - @test vicopy[m_vns] != 1:5 - @test vicopy[s_vns] == 42 end end @@ -466,9 +415,6 @@ end ks = [@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[1, 2]), @varname(x[2, 2])] DynamicPPL.setval!(vi, vi.metadata.x.vals, ks) @test vals_prev == vi.metadata.x.vals - - DynamicPPL.setval_and_resample!(vi, vi.metadata.x.vals, ks) - @test vals_prev == vi.metadata.x.vals end @testset "setval! on chain" begin @@ -533,17 +479,18 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model using SampleFromUniform does not + # Check that instantiating the model using InitFromUniform does not # perform linking - # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) - # specifically in this test is because SFU samples from the linked - # distribution i.e. in unconstrained space. However, it does this not - # by linking the varinfo but by transforming the distributions on the - # fly. That's why it's worth specifically checking that it can do this - # without having to change the VarInfo object. + # Note (penelopeysm): The purpose of using InitFromUniform specifically in + # this test is because it samples from the linked distribution i.e. in + # unconstrained space. However, it does this not by linking the varinfo + # but by transforming the distributions on the fly. That's why it's + # worth specifically checking that it can do this without having to + # change the VarInfo object. + # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. vi = VarInfo() meta = vi.metadata - _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) + _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -607,7 +554,7 @@ end function test_linked_varinfo(model, vi) # vn and dist are taken from the containing scope - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test istrans(vi, vn) @@ -618,6 +565,11 @@ end @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end + ### `VarInfo` + # Need to run once since we can't specify that we want to _sample_ + # in the unconstrained space for `VarInfo` without having `vn` + # present in the `varinfo`. + ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) @@ -628,11 +580,6 @@ end vi = DynamicPPL.settrans!!(vi, true, vn) test_linked_varinfo(model, vi) - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) - test_linked_varinfo(model, vi) - ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) @@ -1012,45 +959,6 @@ end @test merge(vi_double, vi_single)[vn] == 1.0 end - @testset "sampling from linked varinfo" begin - # `~` - @model function demo(n=1) - x = Vector(undef, n) - for i in eachindex(x) - x[i] ~ Exponential() - end - return x - end - model1 = demo(1) - varinfo1 = DynamicPPL.link!!(VarInfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x[1]), @varname(x[2])] - @test DynamicPPL.istrans(varinfo2, vn) - end - - # `.~` - @model function demo_dot(n=1) - x ~ Exponential() - if n > 1 - y = Vector(undef, n - 1) - y .~ Exponential() - end - return x - end - model1 = demo_dot(1) - varinfo1 = DynamicPPL.link!!(DynamicPPL.untyped_varinfo(model1), model1) - # Sampling from `model2` should hit the `istrans(vi) == true` branches - # because all the existing variables are linked. - model2 = demo_dot(2) - varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) - for vn in [@varname(x), @varname(y[1])] - @test DynamicPPL.istrans(varinfo2, vn) - end - end - # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. @testset "vector_getranges for `VarInfo`" begin diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..af24be86f 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -610,9 +610,7 @@ end DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? - varinfo_sample = last( - DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) - ) + varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) # Log density should be different. @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. From d3d32e49b0cd4c1ff46d87ef0527f0d21bcbedf2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Sep 2025 16:48:37 +0100 Subject: [PATCH 08/14] `InitContext`, part 5 - Remove `SamplingContext`, `SampleFrom{Prior,Uniform}`, `{tilde_,}assume` (#985) * Remove `SamplingContext` for good * Remove `tilde_assume` as well * Split up tilde_observe!! for Distribution / Submodel * Tidy up tilde-pipeline methods and docstrings * Fix tests * fix ambiguity * Add changelog * Update HISTORY.md Co-authored-by: Markus Hauru --------- Co-authored-by: Markus Hauru --- HISTORY.md | 62 +++++++++- docs/src/api.md | 27 ++-- ext/DynamicPPLEnzymeCoreExt.jl | 2 - src/DynamicPPL.jl | 8 +- src/context_implementations.jl | 217 +++++++++++++-------------------- src/contexts.jl | 71 +---------- src/contexts/init.jl | 10 +- src/debug_utils.jl | 2 +- src/sampler.jl | 45 ------- src/simple_varinfo.jl | 19 --- src/transforming.jl | 15 ++- src/utils.jl | 44 ------- test/Project.toml | 2 - test/ad.jl | 43 ------- test/contexts.jl | 22 +--- test/debug_utils.jl | 2 +- test/ext/DynamicPPLJETExt.jl | 6 + test/lkj.jl | 34 ++---- test/sampler.jl | 54 -------- test/threadsafe.jl | 10 +- 20 files changed, 205 insertions(+), 490 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index ddbe67842..d67afcbfe 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,8 +2,66 @@ ## 0.38.0 -The `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. -Their behaviour is otherwise identical. +**Breaking changes** + +### Introduction of `InitContext` + +DynamicPPL 0.38 introduces a new evaluation context, `InitContext`. +It is used to generate fresh values for random variables in a model. + +Evaluation contexts are stored inside a `DynamicPPL.Model` object, and control what happens with tilde-statements when a model is run. +The two major leaf (basic) contexts are `DefaultContext` and, now, `InitContext`. +`DefaultContext` is the default context, and it simply uses the values that are already stored in the `VarInfo` object passed to the model evaluation function. +On the other hand, `InitContext` ignores values in the VarInfo object and inserts new values obtained from a specified source. +(It follows also that the VarInfo being used may be empty, which means that `InitContext` is now also the way to obtain a fresh VarInfo for a model.) + +DynamicPPL 0.38 provides three flavours of _initialisation strategies_, which are specified as the second argument to `InitContext`: + + - `InitContext(rng, InitFromPrior())`: New values are sampled from the prior distribution (on the right-hand side of the tilde). + - `InitContext(rng, InitFromUniform(a, b))`: New values are sampled uniformly from the interval `[a, b]`, and then invlinked to the support of the distribution on the right-hand side of the tilde. + - `InitContext(rng, InitFromParams(p, fallback))`: New values are obtained by indexing into the `p` object, which can be a `NamedTuple` or `Dict{<:VarName}`. If a variable is not found in `p`, then the `fallback` strategy is used, which is simply another of these strategies. In particular, `InitFromParams` enables the case where different variables are to be initialised from different sources. + +(It is possible to define your own initialisation strategy; users who wish to do so are referred to the DynamicPPL API documentation and source code.) + +**The main impact on the upcoming Turing.jl release** is that, instead of providing initial values for sampling, the user will be expected to provide an initialisation strategy instead. +This is a more flexible approach, and not only solves a number of pre-existing issues with initialisation of Turing models, but also improves the clarity of user code. +In particular: + + - When providing a set of fixed parameters (i.e. `InitFromParams(p)`), `p` must now either be a NamedTuple or a Dict. Previously Vectors were allowed, which is error-prone because the ordering of variables in a VarInfo is not obvious. + - The parameters in `p` must now always be provided in unlinked space (i.e., in the space of the distribution on the right-hand side of the tilde). Previously, whether a parameter was expected to be in linked or unlinked space depended on whether the VarInfo was linked or not, which was confusing. + +### Removal of `SamplingContext` + +For developers working on DynamicPPL, `InitContext` now completely replaces what used to be `SamplingContext`, `SampleFromPrior`, and `SampleFromUniform`. +Evaluating a model with `SamplingContext(SampleFromPrior())` (e.g. with `DynamicPPL.evaluate_and_sample!!(model, VarInfo(), SampleFromPrior())` has a direct one-to-one replacement in `DynamicPPL.init!!(model, VarInfo(), InitFromPrior())`. +Please see the docstring of `init!!` for more details. +Likewise `SampleFromUniform()` can be replaced with `InitFromUniform()`. +`InitFromParams()` provides new functionality which was previously implemented in the roundabout way of manipulating the VarInfo (e.g. using `unflatten`, or even more hackily by directly modifying values in the VarInfo), and then evaluating using `DefaultContext`. + +The main change that this is likely to create is for those who are implementing samplers or inference algorithms. +The exact way in which this happens will be detailed in the Turing.jl changelog when a new release is made. +Broadly speaking, though, `SamplingContext(MySampler())` will be removed so if your sampler needs custom behaviour with the tilde-pipeline you will likely have to define your own context. + +### Simplification of the tilde-pipeline + +There are now only two functions in the tilde-pipeline that need to be overloaded to change the behaviour of tilde-statements, namely, `tilde_assume!!` and `tilde_observe!!`. +Other functions such as `tilde_assume` and `assume` (and their `observe` counterparts) have been removed. + +Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other). +The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details). + +**Other changes** + +### Reimplementation of functions using `InitContext` + +A number of functions have been reimplemented and unified with the help of `InitContext`. +In particular, this release brings substantial performance improvements for `returned` and `predict`. +Their APIs are the same. + +### Upstreaming of VarName functionality + +The implementation of the `varname_leaves` and `varname_and_value_leaves` functions have been moved to AbstractPPL.jl. +Their behaviour is otherwise identical, and they are still accessible from the DynamicPPL module (though still not exported). ## 0.37.3 diff --git a/docs/src/api.md b/docs/src/api.md index d1dddb560..e5c483bca 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,7 +8,7 @@ Part of the API of DynamicPPL is defined in the more lightweight interface packa A core component of DynamicPPL is the [`@model`](@ref) macro. It can be used to define probabilistic models in an intuitive way by specifying random variables and their distributions with `~` statements. -These statements are rewritten by `@model` as calls of [internal functions](@ref model_internal) for sampling the variables and computing their log densities. +These statements are rewritten by `@model` as calls of internal functions for sampling the variables and computing their log densities. ```@docs @model @@ -344,6 +344,13 @@ Base.empty! SimpleVarInfo ``` +### Tilde-pipeline + +```@docs +tilde_assume!! +tilde_observe!! +``` + ### Accumulators The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. @@ -447,12 +454,12 @@ AbstractPPL.evaluate!! This method mutates the `varinfo` used for execution. By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +If you wish to sample new values, see the section on [VarInfo initialisation](#VarInfo-initialisation) just below this. The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs -SamplingContext DefaultContext PrefixContext ConditionContext @@ -486,15 +493,7 @@ DynamicPPL.init ### Samplers -In DynamicPPL two samplers are defined that are used to initialize unobserved random variables: -[`SampleFromPrior`](@ref) which samples from the prior distribution, and [`SampleFromUniform`](@ref) which samples from a uniform distribution. - -```@docs -SampleFromPrior -SampleFromUniform -``` - -Additionally, a generic sampler for inference is implemented. +In DynamicPPL a generic sampler for inference is implemented. ```@docs Sampler @@ -520,9 +519,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va DynamicPPL.Experimental.determine_suitable_varinfo DynamicPPL.Experimental.is_suitable_varinfo ``` - -### [Model-Internal Functions](@id model_internal) - -```@docs -tilde_assume -``` diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index d592e76b3..0088f8908 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true - # Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6a01884a9..edf44439e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -96,18 +96,16 @@ export AbstractVarInfo, values_as_in_model, # Samplers Sampler, - SampleFromPrior, - SampleFromUniform, # LogDensityFunction LogDensityFunction, # Contexts contextualize, - SamplingContext, DefaultContext, PrefixContext, ConditionContext, - assume, - tilde_assume, + # Tilde pipeline + tilde_assume!!, + tilde_observe!!, # Initialisation InitContext, AbstractInitStrategy, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..a8f2d57e6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,42 +1,37 @@ -# assume """ - tilde_assume(context::SamplingContext, right, vn, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -``` + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo + ) + +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. + +`vn` is the VarName on the left-hand side of the tilde statement. """ -function tilde_assume(context::SamplingContext, right, vn, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) -end - -function tilde_assume(context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end -function tilde_assume(::DefaultContext, right, vn, vi) - return assume(right, vn, vi) -end - -function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(rng, childcontext(context), args...) -end -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) - return assume(rng, sampler, right, vn, vi) +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + return tilde_assume!!(childcontext(context), right, vn, vi) end -function tilde_assume(::DefaultContext, sampler, right, vn, vi) - # same as above but no rng - return assume(Random.default_rng(), sampler, right, vn, vi) +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi end - -function tilde_assume(context::PrefixContext, right, vn, vi) +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) # Note that we can't use something like this here: # new_vn = prefix(context, vn) - # return tilde_assume(childcontext(context), right, new_vn, vi) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) # This is because `prefix` applies _all_ prefixes in a given context to a # variable name. Thus, if we had two levels of nested prefixes e.g. # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the @@ -44,50 +39,64 @@ function tilde_assume(context::PrefixContext, right, vn, vi) # would apply the prefix `b._`, resulting in `b.a.b._`. # This is why we need a special function, `prefix_and_strip_contexts`. new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(new_context, right, new_vn, vi) + return tilde_assume!!(new_context, right, new_vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) + +Evaluate the submodel with the given context. +""" +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo ) - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume(rng, new_context, sampler, right, new_vn, vi) + return _evaluate!!(right, vi, context, vn) end """ - tilde_assume!!(context, right, vn, vi) + tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo + ) -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value and updated `vi`. +This function handles observed variables, which may be: -By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_assume!!(context, right, vn, vi) - return if right isa DynamicPPL.Submodel - _evaluate!!(right, vi, context, vn) - else - tilde_assume(context, right, vn, vi) - end -end +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. -# observe -""" - tilde_observe!!(context::SamplingContext, right, left, vi) +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). -Handle observed constants with a `context` associated with a sampler. +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. -Falls back to `tilde_observe!!(context.context, right, left, vi)`. +Observations of submodels are not yet supported in DynamicPPL. """ -function tilde_observe!!(context::SamplingContext, right, left, vn, vi) - return tilde_observe!!(context.context, right, left, vn, vi) -end - -function tilde_observe!!(context::AbstractContext, right, left, vn, vi) +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(childcontext(context), right, left, vn, vi) end - -# `PrefixContext` -function tilde_observe!!(context::PrefixContext, right, left, vn, vi) +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal # value. For the need for prefix_and_strip_contexts rather than just prefix, see the # comment in `tilde_assume!!`. @@ -98,74 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi) end return tilde_observe!!(new_context, right, left, new_vn, vi) end - -""" - tilde_observe!!(context, right, left, vn, vi) - -Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and updated `vi`. - -Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name -and indices; if needed, these can be accessed through this function, though. -""" -function tilde_observe!!(::DefaultContext, right, left, vn, vi) - right isa DynamicPPL.Submodel && - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +function tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) vi = accumulate_observe!!(vi, right, left, vn) return left, vi end - -function assume(::Random.AbstractRNG, spl::Sampler, dist) - return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") -end - -# fallback without sampler -function assume(dist::Distribution, vn::VarName, vi) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) - return x, vi -end - -# TODO: Remove this thing. -# SampleFromPrior and SampleFromUniform -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, ) - if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - # TODO(mhauru) Is it important to unset the flag here? The `true` allows us - # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure - # if that's okay. - unset_flag!(vi, vn, "del", true) - r = init(rng, dist, sampler) - f = to_maybe_linked_internal_transform(vi, vn, dist) - # TODO(mhauru) This should probably be call a function called setindex_internal! - vi = BangBang.setindex!!(vi, f(r), vn) - else - # Otherwise we just extract it. - r = vi[vn, dist] - end - else - r = init(rng, dist, sampler) - if istrans(vi) - f = to_linked_internal_transform(vi, vn, dist) - vi = push!!(vi, vn, f(r), dist) - # By default `push!!` sets the transformed flag to `false`. - vi = settrans!!(vi, true, vn) - else - vi = push!!(vi, vn, r, dist) - end - end - - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, logjac, vn, dist) - return r, vi + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) end diff --git a/src/contexts.jl b/src/contexts.jl index cd9876768..439da47e5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -47,7 +47,7 @@ effectively updating the child context. ```jldoctest julia> using DynamicPPL: DynamicTransformationContext -julia> ctx = SamplingContext(); +julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() @@ -121,73 +121,6 @@ setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right # Contexts -""" - SamplingContext( - [rng::Random.AbstractRNG=Random.default_rng()], - [sampler::AbstractSampler=SampleFromPrior()], - [context::AbstractContext=DefaultContext()], - ) - -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. - -See also: [`DefaultContext`](@ref) -""" -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end - -function SamplingContext( - rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() -) - return SamplingContext(rng, sampler, DefaultContext()) -end - -function SamplingContext( - sampler::AbstractSampler, context::AbstractContext=DefaultContext() -) - return SamplingContext(Random.default_rng(), sampler, context) -end - -function SamplingContext(rng::Random.AbstractRNG, context::AbstractContext) - return SamplingContext(rng, SampleFromPrior(), context) -end - -function SamplingContext(context::AbstractContext) - return SamplingContext(Random.default_rng(), SampleFromPrior(), context) -end - -NodeTrait(context::SamplingContext) = IsParent() -childcontext(context::SamplingContext) = context.context -function setchildcontext(parent::SamplingContext, child) - return SamplingContext(parent.rng, parent.sampler, child) -end - -""" - hassampler(context) - -Return `true` if `context` has a sampler. -""" -hassampler(::SamplingContext) = true -hassampler(context::AbstractContext) = hassampler(NodeTrait(context), context) -hassampler(::IsLeaf, context::AbstractContext) = false -hassampler(::IsParent, context::AbstractContext) = hassampler(childcontext(context)) - -""" - getsampler(context) - -Return the sampler of the context `context`. - -This will traverse the context tree until it reaches the first [`SamplingContext`](@ref), -at which point it will return the sampler of that context. -""" -getsampler(context::SamplingContext) = context.sampler -getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context) -getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context)) -getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") - """ struct DefaultContext <: AbstractContext end @@ -252,7 +185,7 @@ PrefixContexts removed. NOTE: This does _not_ modify any variables in any `ConditionContext` and `FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume`, which is lower in the tilde-pipeline +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline than `contextual_isassumption` and `contextual_isfixed` (the functions which actually use the `ConditionContext` and `FixedContext` values). Thus, by this time, any `ConditionContext`s and `FixedContext`s present have already served diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 636847117..4baca1b57 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -154,7 +154,7 @@ struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractCon end NodeTrait(::InitContext) = IsLeaf() -function tilde_assume( +function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) in_varinfo = haskey(vi, vn) @@ -191,6 +191,12 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::InitContext, right, left, vn, vi) +function tilde_observe!!( + ::InitContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..2ec8b15a2 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -485,7 +485,7 @@ and checking if the model is consistent across runs. function has_static_constraints( rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) - new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) + new_model = DynamicPPL.contextualize(model, InitContext(rng)) results = map(1:num_evals) do _ check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end diff --git a/src/sampler.jl b/src/sampler.jl index 98b50ba55..8b49f6c3b 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -1,34 +1,3 @@ -# TODO: Make `UniformSampling` and `Prior` algs + just use `Sampler` -# That would let us use all defaults for Sampler, combine it with other samplers etc. -""" - SampleFromUniform - -Sampling algorithm that samples unobserved random variables from a uniform distribution. - -# References - -[Stan reference manual](https://mc-stan.org/docs/2_28/reference-manual/initialization.html#random-initial-values) -""" -struct SampleFromUniform <: AbstractSampler end - -""" - SampleFromPrior - -Sampling algorithm that samples unobserved random variables from their prior distribution. -""" -struct SampleFromPrior <: AbstractSampler end - -# Initializations. -init(rng, dist, ::SampleFromPrior) = rand(rng, dist) -function init(rng, dist, ::SampleFromUniform) - return istransformable(dist) ? inittrans(rng, dist) : rand(rng, dist) -end - -init(rng, dist, ::SampleFromPrior, n::Int) = rand(rng, dist, n) -function init(rng, dist, ::SampleFromUniform, n::Int) - return istransformable(dist) ? inittrans(rng, dist, n) : rand(rng, dist, n) -end - # TODO(mhauru) Could we get rid of Sampler now that it's just a wrapper around `alg`? # (Selector has been removed). """ @@ -49,20 +18,6 @@ struct Sampler{T} <: AbstractSampler alg::T end -# AbstractMCMC interface for SampleFromUniform and SampleFromPrior -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - sampler::Union{SampleFromUniform,SampleFromPrior}, - state=nothing; - kwargs..., -) - vi = VarInfo() - strategy = sampler isa SampleFromPrior ? InitFromPrior() : InitFromUniform() - _, new_vi = DynamicPPL.init!!(rng, model, vi, strategy) - return new_vi, nothing -end - """ default_varinfo(rng, model, sampler) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 27365e4dc..f430755e7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -466,25 +466,6 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -# Context implementations -# NOTE: Evaluations, i.e. those without `rng` are shared with other -# implementations of `AbstractVarInfo`. -function assume( - rng::Random.AbstractRNG, - sampler::Union{SampleFromPrior,SampleFromUniform}, - dist::Distribution, - vn::VarName, - vi::SimpleOrThreadSafeSimple, -) - value = init(rng, dist, sampler) - # Transform if we're working in unconstrained space. - f = to_maybe_linked_internal_transform(vi, vn, dist) - value_raw, logjac = with_logabsdet_jacobian(f, value) - vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, logjac, vn, dist) - return value, vi -end - function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) end diff --git a/src/transforming.jl b/src/transforming.jl index 56f861cff..589dca031 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -12,8 +12,11 @@ how to do the transformation, used by e.g. `SimpleVarInfo`. struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() -function tilde_assume( - ::DynamicTransformationContext{isinverse}, right, vn, vi +function tilde_assume!!( + ::DynamicTransformationContext{isinverse}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) where {isinverse} # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] @@ -31,7 +34,13 @@ function tilde_assume( return x, vi end -function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) +function tilde_observe!!( + ::DynamicTransformationContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end diff --git a/src/utils.jl b/src/utils.jl index c7d1e089f..a4c5f4a1b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -456,50 +456,6 @@ function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) return copy(reshape(val, length(d), n)) end -# Uniform random numbers with range 4 for robust initializations -# Reference: https://mc-stan.org/docs/2_19/reference-manual/initialization.html -randrealuni(rng::Random.AbstractRNG) = 4 * rand(rng) - 2 -randrealuni(rng::Random.AbstractRNG, args...) = 4 .* rand(rng, args...) .- 2 - -istransformable(dist) = link_transform(dist) !== identity - -################################# -# Single-sample initialisations # -################################# - -inittrans(rng, dist::UnivariateDistribution) = Bijectors.invlink(dist, randrealuni(rng)) -function inittrans(rng, dist::MultivariateDistribution) - # Get the length of the unconstrained vector - b = link_transform(dist) - d = Bijectors.output_length(b, length(dist)) - return Bijectors.invlink(dist, randrealuni(rng, d)) -end -function inittrans(rng, dist::MatrixDistribution) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -function inittrans(rng, dist::Distribution{CholeskyVariate}) - # Get the size of the unconstrained vector - b = link_transform(dist) - sz = Bijectors.output_size(b, size(dist)) - return Bijectors.invlink(dist, randrealuni(rng, sz...)) -end -################################ -# Multi-sample initialisations # -################################ - -function inittrans(rng, dist::UnivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, n)) -end -function inittrans(rng, dist::MultivariateDistribution, n::Int) - return Bijectors.invlink(dist, randrealuni(rng, size(dist)[1], n)) -end -function inittrans(rng, dist::MatrixDistribution, n::Int) - return Bijectors.invlink(dist, [randrealuni(rng, size(dist)...) for _ in 1:n]) -end - ####################### # Convenience methods # ####################### diff --git a/test/Project.toml b/test/Project.toml index 537214464..3e7ffcaea 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -39,7 +38,6 @@ DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" -EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" JET = "0.9 - 0.10.6" LogDensityProblems = "2" diff --git a/test/ad.jl b/test/ad.jl index 0e5d8d7cf..23e676ee7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -77,49 +77,6 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest end end - @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin - # Failing model - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - # measurement model - y ~ MvNormal(x, σ^2 * I) - return x - end - model = state_space(y, length(t)) - - # Dummy sampling algorithm for testing. The test case can only be replicated - # with a custom sampler, it doesn't work with SampleFromPrior(). We need to - # overload assume so that model evaluation doesn't fail due to a lack - # of implementation - struct MyEmptyAlg end - DynamicPPL.assume( - ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi - ) = DynamicPPL.assume(dist, vn, vi) - - # Compiling the ReverseDiff tape used to fail here - spl = Sampler(MyEmptyAlg()) - vi = DynamicPPL.link!!(VarInfo(model), model) - sampling_model = contextualize(model, SamplingContext(model.context)) - ldf = LogDensityFunction( - sampling_model, getlogjoint_internal, vi; adtype=AutoReverseDiff(; compile=true) - ) - x = ldf.varinfo[:] - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any - end - # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/contexts.jl b/test/contexts.jl index 1a6279bf4..2687c4336 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -24,8 +24,6 @@ using DynamicPPL: using LinearAlgebra: I using Random: Xoshiro -using EnzymeCore - # TODO: Should we maybe put this in DPPL itself? function Base.iterate(context::AbstractContext) if NodeTrait(context) isa IsLeaf @@ -150,11 +148,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vn = @varname(x[1]) ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = SamplingContext(ctx1) + ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.SamplingContext(ctx3) + ctx4 = FixedContext(Dict(), ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -203,22 +201,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "SamplingContext" begin - context = SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()) - @test context isa SamplingContext - - # convenience constructors - @test SamplingContext() == context - @test SamplingContext(Random.default_rng()) == context - @test SamplingContext(SampleFromPrior()) == context - @test SamplingContext(DefaultContext()) == context - @test SamplingContext(Random.default_rng(), SampleFromPrior()) == context - @test SamplingContext(Random.default_rng(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test SamplingContext(SampleFromPrior(), DefaultContext()) == context - @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 5bf741ff3..f950f6b45 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -149,7 +149,7 @@ model = demo_missing_in_multivariate([1.0, missing]) # Have to run this check_model call with an empty varinfo, because actually # instantiating the VarInfo would cause it to throw a MethodError. - model = contextualize(model, SamplingContext()) + model = contextualize(model, InitContext()) @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 692f53911..b34424a1c 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -64,6 +64,12 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) + # Check that the inferred varinfo is indeed suitable for evaluation + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f_eval, argtypes_eval) + # For our demo models, they should all result in typed. is_typed = varinfo isa DynamicPPL.NTVarInfo @test is_typed diff --git a/test/lkj.jl b/test/lkj.jl index d581cd21b..5c5603aba 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -16,20 +16,15 @@ end # Same for both distributions target_mean = vec(Matrix{Float64}(I, 2, 2)) +n_samples = 1000 _lkj_atol = 0.05 @testset "Sample from x ~ LKJ(2, 1)" begin model = lkj_prior_demo() - # `SampleFromPrior` will sample in constrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = - _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -37,21 +32,10 @@ end @testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L'] model = lkj_chol_prior_demo(uplo) - # `SampleFromPrior` will sample in unconstrained space. - @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000; progress=false) - # Build correlation matrix from factor - corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) - pd_from_triangular(M, uplo) - end - @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol - end - - # `SampleFromUniform` will sample in unconstrained space. - @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000; progress=false) - # Build correlation matrix from factor + for init_strategy in [InitFromPrior(), InitFromUniform()] + samples = [ + last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples + ] corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) pd_from_triangular(M, uplo) diff --git a/test/sampler.jl b/test/sampler.jl index c812de938..5380ad17e 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -113,60 +113,6 @@ end end - @testset "SampleFromPrior and SampleUniform" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(2.0, sqrt(s)) - x ~ Normal(m, sqrt(s)) - return y ~ Normal(m, sqrt(s)) - end - - model = gdemo(1.0, 2.0) - N = 1_000 - - chains = sample(model, SampleFromPrior(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 - - # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 - - chains = sample(model, SampleFromUniform(), N; progress=false) - @test chains isa Vector{<:VarInfo} - @test length(chains) == N - - # `m` is Gaussian, i.e. no transformation is used, so it - # will be drawn from U[-2, 2] and its mean should be 0. - @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 - - # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. - @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 - end - - @testset "init" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end - end - end - end - @testset "Initial parameters" begin # dummy algorithm that just returns initial value and does not perform any sampling abstract type OnlyInitAlg end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 0421c89e2..522730566 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -68,8 +68,7 @@ @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo @@ -77,7 +76,7 @@ @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadsafe!!(model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -104,13 +103,12 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + DynamicPPL.evaluate_threadunsafe!!(model, vi) @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @time DynamicPPL.evaluate_threadunsafe!!(model, vi) end end From 5a98037a91f87cdb5a82ed3f1820dfb76270dd9b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 24 Sep 2025 23:51:35 +0100 Subject: [PATCH 09/14] fix missing import --- src/test_utils/sampler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/test_utils/sampler.jl b/src/test_utils/sampler.jl index 3ef965bad..2101388fb 100644 --- a/src/test_utils/sampler.jl +++ b/src/test_utils/sampler.jl @@ -3,6 +3,8 @@ # # Utilities to test samplers on models. +using AbstractPPL: AbstractPPL + """ marginal_mean_of_samples(chain, varname) From 7311465a48d9f8c86bd05160b6437904e7a3b68d Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 25 Sep 2025 13:12:07 +0100 Subject: [PATCH 10/14] Shuffle context code around and remove dead code (#1050) --- src/DynamicPPL.jl | 6 +- src/abstract_varinfo.jl | 50 +- src/context_implementations.jl | 128 ---- src/contexts.jl | 580 ++---------------- src/contexts/conditionfix.jl | 467 ++++++++++++++ src/contexts/default.jl | 60 ++ src/contexts/prefix.jl | 116 ++++ .../transformation.jl} | 28 - src/debug_utils.jl | 1 - src/submodel.jl | 34 + src/utils.jl | 4 - 11 files changed, 755 insertions(+), 719 deletions(-) delete mode 100644 src/context_implementations.jl create mode 100644 src/contexts/conditionfix.jl create mode 100644 src/contexts/default.jl create mode 100644 src/contexts/prefix.jl rename src/{transforming.jl => contexts/transformation.jl} (61%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index edf44439e..b1b3bc3d9 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -174,7 +174,11 @@ abstract type AbstractVarInfo <: AbstractModelTrace end include("utils.jl") include("chains.jl") include("contexts.jl") +include("contexts/default.jl") include("contexts/init.jl") +include("contexts/transformation.jl") +include("contexts/prefix.jl") +include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("sampler.jl") include("varname.jl") @@ -187,10 +191,8 @@ include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") -include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ac841baab..b3cf77121 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -827,6 +827,27 @@ end function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end +function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + ctx = DynamicTransformationContext{false}() + model = contextualize(model, setleafcontext(model.context, ctx)) + vi = last(evaluate!!(model, vi)) + return settrans!!(vi, t) +end +function link!!( + t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model +) + b = inverse(t.bijector) + x = vi[:] + y, logjac = with_logabsdet_jacobian(b, x) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return settrans!!(vi, t) +end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -846,6 +867,9 @@ end function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link(default_transformation(model, vi), vi, vns, model) end +function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return link!!(t, deepcopy(vi), model) +end """ invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -866,23 +890,14 @@ end function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end - -# Vector-based ones. -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - - # Set parameters and add the logjac term. - vi = unflatten(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return settrans!!(vi, t) +function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) + # Note that in practice this method is only called for SimpleVarInfo, because VarInfo + # has a dedicated implementation + ctx = DynamicTransformationContext{true}() + model = contextualize(model, setleafcontext(model.context, ctx)) + vi = last(evaluate!!(model, vi)) + return settrans!!(vi, NoTransformation()) end - function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) @@ -919,6 +934,9 @@ end function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end +function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) + return invlink!!(t, deepcopy(vi), model) +end """ maybe_invlink_before_eval!!([t::Transformation,] vi, model) diff --git a/src/context_implementations.jl b/src/context_implementations.jl deleted file mode 100644 index a8f2d57e6..000000000 --- a/src/context_implementations.jl +++ /dev/null @@ -1,128 +0,0 @@ -""" - DynamicPPL.tilde_assume!!( - context::AbstractContext, - right::Distribution, - vn::VarName, - vi::AbstractVarInfo - ) - -Handle assumed variables, i.e. anything which is not observed (see -[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the -sampled value and updated `vi`. - -`vn` is the VarName on the left-hand side of the tilde statement. -""" -function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - return tilde_assume!!(childcontext(context), right, vn, vi) -end -function tilde_assume!!( - ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, right) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) - return x, vi -end -function tilde_assume!!( - context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume!!(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume!!(new_context, right, new_vn, vi) -end -""" - DynamicPPL.tilde_assume!!( - context::AbstractContext, - right::DynamicPPL.Submodel, - vn::VarName, - vi::AbstractVarInfo - ) - -Evaluate the submodel with the given context. -""" -function tilde_assume!!( - context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo -) - return _evaluate!!(right, vi, context, vn) -end - -""" - tilde_observe!!( - context::AbstractContext, - right::Distribution, - left, - vn::Union{VarName, Nothing}, - vi::AbstractVarInfo - ) - -This function handles observed variables, which may be: - -- literals on the left-hand side, e.g., `3.0 ~ Normal()` -- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` -- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. - -The relevant log-probability associated with the observation is computed and accumulated in -the VarInfo object `vi` (except for fixed variables, which do not contribute to the -log-probability). - -`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the -left-hand side, or `nothing` if the left-hand side is a literal value. - -Observations of submodels are not yet supported in DynamicPPL. -""" -function tilde_observe!!( - context::AbstractContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - return tilde_observe!!(childcontext(context), right, left, vn, vi) -end -function tilde_observe!!( - context::PrefixContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. - new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end -function tilde_observe!!( - ::DefaultContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - vi = accumulate_observe!!(vi, right, left, vn) - return left, vi -end -function tilde_observe!!( - ::AbstractContext, - ::DynamicPPL.Submodel, - left, - vn::Union{VarName,Nothing}, - ::AbstractVarInfo, -) - throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) -end diff --git a/src/contexts.jl b/src/contexts.jl index 439da47e5..70f99a73f 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,6 +1,3 @@ -# Fallback traits -# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? - """ NodeTrait(context) NodeTrait(f, context) @@ -120,559 +117,62 @@ end setleafcontext(::IsLeaf, ::IsParent, left, right) = right setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right -# Contexts -""" - struct DefaultContext <: AbstractContext end - -The `DefaultContext` is used by default to accumulate values like the log joint probability -when running the model. -""" -struct DefaultContext <: AbstractContext end -NodeTrait(::DefaultContext) = IsLeaf() - -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. - -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) - -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - """ - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. -""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end - -""" - - ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} - -Model context that contains values that are to be conditioned on. The values -can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or -an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, -@varname(b) => 2)`). The former is more performant, but the latter must be used -when there are varnames that cannot be represented as symbols, e.g. -`@varname(x[1])`. -""" -struct ConditionContext{ - Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext -} <: AbstractContext - values::Values - context::Ctx -end - -const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} -const DictConditionContext = ConditionContext{<:AbstractDict} - -# Use DefaultContext as the default base context -function ConditionContext(values::Union{NamedTuple,AbstractDict}) - return ConditionContext(values, DefaultContext()) -end -# Optimisation when there are no values to condition on -ConditionContext(::NamedTuple{()}, context::AbstractContext) = context -# Same as above, and avoids method ambiguity with below -ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context -# Collapse consecutive levels of `ConditionContext`. Note that this overrides -# values inside the child context, thus giving precedence to the outermost -# `ConditionContext`. -function ConditionContext(values::NamedTuple, context::NamedConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end -function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) - return ConditionContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::ConditionContext) - return print(io, "ConditionContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context -setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) - -""" - hasconditioned(context::AbstractContext, vn::VarName) - -Return `true` if `vn` is found in `context`. -""" -hasconditioned(context::AbstractContext, vn::VarName) = false -hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) -function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end - -""" - getconditioned(context::AbstractContext, vn::VarName) - -Return value of `vn` in `context`. -""" -function getconditioned(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -function getconditioned(context::ConditionContext, vn::VarName) - return getvalue(context.values, vn) -end - -""" - hasconditioned_nested(context, vn) - -Return `true` if `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasconditioned_nested(context::AbstractContext, vn) - return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) -end -hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) -function hasconditioned_nested(::IsParent, context, vn) - return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) -end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end - -""" - getconditioned_nested(context, vn) - -Return the value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getconditioned_nested(context::AbstractContext, vn) - return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) -end -function getconditioned_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end -function getconditioned_nested(::IsParent, context, vn) - return if hasconditioned(context, vn) - getconditioned(context, vn) - else - getconditioned_nested(childcontext(context), vn) - end -end - -""" - decondition(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer conditioned on. - -Note that this recursively traverses contexts, deconditioning all along the way. - -See also: [`condition`](@ref) -""" -decondition_context(::IsLeaf, context, args...) = context -function decondition_context(::IsParent, context, args...) - return setchildcontext(context, decondition_context(childcontext(context), args...)) -end -function decondition_context(context, args...) - return decondition_context(NodeTrait(context), context, args...) -end -function decondition_context(context::ConditionContext) - return decondition_context(childcontext(context)) -end -function decondition_context(context::ConditionContext, sym, syms...) - new_values = deepcopy(context.values) - for s in (sym, syms...) - new_values = BangBang.delete!!(new_values, s) - end - return if length(new_values) == 0 - # No more values left, can unwrap - decondition_context(childcontext(context), syms...) - else - ConditionContext( - new_values, decondition_context(childcontext(context), sym, syms...) - ) - end -end -function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} - return ConditionContext( - BangBang.delete!!(context.values, sym), - decondition_context(childcontext(context), vn), - ) -end - -""" - conditioned(context::AbstractContext) - -Return `NamedTuple` of values that are conditioned on under context`. - -Note that this will recursively traverse the context stack and return -a merged version of the condition values. -""" -function conditioned(context::AbstractContext) - return conditioned(NodeTrait(conditioned, context), context) -end -conditioned(::IsLeaf, context) = NamedTuple() -conditioned(::IsParent, context) = conditioned(childcontext(context)) -function conditioned(context::ConditionContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `conditioned` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, conditioned(childcontext(context))) -end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end -struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext - values::Values - context::Ctx -end - -const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} -const DictFixedContext = FixedContext{<:AbstractDict} - -FixedContext(values) = FixedContext(values, DefaultContext()) - -# Try to avoid nested `FixedContext`. -function FixedContext(values::NamedTuple, context::NamedFixedContext) - # Note that this potentially overrides values from `context`, thus giving - # precedence to the outmost `FixedContext`. - return FixedContext(merge(context.values, values), childcontext(context)) -end - -function Base.show(io::IO, context::FixedContext) - return print(io, "FixedContext($(context.values), $(childcontext(context)))") -end - -NodeTrait(::FixedContext) = IsParent() -childcontext(context::FixedContext) = context.context -setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) - -""" - hasfixed(context::AbstractContext, vn::VarName) +Handle assumed variables, i.e. anything which is not observed (see +[`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the +sampled value and updated `vi`. -Return `true` if a fixed value for `vn` is found in `context`. -""" -hasfixed(context::AbstractContext, vn::VarName) = false -hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) -function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(hasvalue, context.values), vns) -end +`vn` is the VarName on the left-hand side of the tilde statement. +This function should return a tuple `(x, vi)`, where `x` is the sampled value (which +must be in unlinked space!) and `vi` is the updated VarInfo. """ - getfixed(context::AbstractContext, vn::VarName) - -Return the fixed value of `vn` in `context`. -""" -function getfixed(context::AbstractContext, vn::VarName) - return error("context $(context) does not contain value for $vn") -end -getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) - -""" - hasfixed_nested(context, vn) - -Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. - -This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks -for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. -""" -function hasfixed_nested(context::AbstractContext, vn) - return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) -end -hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) -function hasfixed_nested(::IsParent, context, vn) - return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) -end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end - -""" - getfixed_nested(context, vn) - -Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. - -This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, -not recursively looking into its descendants. -""" -function getfixed_nested(context::AbstractContext, vn) - return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) -end -function getfixed_nested(::IsLeaf, context, vn) - return error("context $(context) does not contain value for $vn") -end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end -function getfixed_nested(::IsParent, context, vn) - return if hasfixed(context, vn) - getfixed(context, vn) - else - getfixed_nested(childcontext(context), vn) - end -end - -""" - fix([context::AbstractContext,] values::NamedTuple) - fix([context::AbstractContext]; values...) - -Return `FixedContext` with `values` and `context` if `values` is non-empty, -otherwise return `context` which is [`DefaultContext`](@ref) by default. - -See also: [`unfix`](@ref) -""" -fix(; values...) = fix(NamedTuple(values)) -fix(values::NamedTuple) = fix(DefaultContext(), values) -function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix((value, values...)) -end -function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) - return fix(DefaultContext(), values) -end -fix(context::AbstractContext, values::NamedTuple{()}) = context -function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) - return FixedContext(values, context) -end -function fix(context::AbstractContext; values...) - return fix(context, NamedTuple(values)) -end -function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) - return fix(context, (value, values...)) -end -function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) - return fix(context, Dict(values)) +function tilde_assume!!( + context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + return tilde_assume!!(childcontext(context), right, vn, vi) end """ - unfix(context::AbstractContext, syms...) - -Return `context` but with `syms` no longer fixed. - -Note that this recursively traverses contexts, unfixing all along the way. - -See also: [`fix`](@ref) -""" -unfix(::IsLeaf, context, args...) = context -function unfix(::IsParent, context, args...) - return setchildcontext(context, unfix(childcontext(context), args...)) -end -function unfix(context, args...) - return unfix(NodeTrait(context), context, args...) -end -function unfix(context::FixedContext) - return unfix(childcontext(context)) -end -function unfix(context::FixedContext, sym) - return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, sym, syms...) - return unfix( - fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), - syms..., + DynamicPPL.tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName, Nothing}, + vi::AbstractVarInfo ) -end - -function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) -end -function unfix(context::FixedContext, vn::VarName) - return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) -end - -""" - fixed(context::AbstractContext) - -Return the values that are fixed under `context`. - -Note that this will recursively traverse the context stack and return -a merged version of the fix values. -""" -fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) -fixed(::IsLeaf, context) = NamedTuple() -fixed(::IsParent, context) = fixed(childcontext(context)) -function fixed(context::FixedContext) - # Note the order of arguments to `merge`. The behavior of the rest of DPPL - # is that the outermost `context` takes precendence, hence when resolving - # the `fixed` variables we need to ensure that `context.values` takes - # precedence over decendants of `context`. - return _merge(context.values, fixed(childcontext(context))) -end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end - -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. - -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); +This function handles observed variables, which may be: -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end - -""" - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) +- literals on the left-hand side, e.g., `3.0 ~ Normal()` +- a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end` +- a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`. -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. +The relevant log-probability associated with the observation is computed and accumulated in +the VarInfo object `vi` (except for fixed variables, which do not contribute to the +log-probability). -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext +`left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the +left-hand side, or `nothing` if the left-hand side is a literal value. -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) +Observations of submodels are not yet supported in DynamicPPL. -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` +This function should return a tuple `(left, vi)`, where `left` is the same as the input, and +`vi` is the updated VarInfo. """ -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName +function tilde_observe!!( + context::AbstractContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, ) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl new file mode 100644 index 000000000..d3802de85 --- /dev/null +++ b/src/contexts/conditionfix.jl @@ -0,0 +1,467 @@ +""" + + ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} + +Model context that contains values that are to be conditioned on. The values +can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or +an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, +@varname(b) => 2)`). The former is more performant, but the latter must be used +when there are varnames that cannot be represented as symbols, e.g. +`@varname(x[1])`. +""" +struct ConditionContext{ + Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext +} <: AbstractContext + values::Values + context::Ctx +end + +const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} +const DictConditionContext = ConditionContext{<:AbstractDict} + +# Use DefaultContext as the default base context +function ConditionContext(values::Union{NamedTuple,AbstractDict}) + return ConditionContext(values, DefaultContext()) +end +# Optimisation when there are no values to condition on +ConditionContext(::NamedTuple{()}, context::AbstractContext) = context +# Same as above, and avoids method ambiguity with below +ConditionContext(::NamedTuple{()}, context::NamedConditionContext) = context +# Collapse consecutive levels of `ConditionContext`. Note that this overrides +# values inside the child context, thus giving precedence to the outermost +# `ConditionContext`. +function ConditionContext(values::NamedTuple, context::NamedConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end +function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) + return ConditionContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) + +""" + hasconditioned(context::AbstractContext, vn::VarName) + +Return `true` if `vn` is found in `context`. +""" +hasconditioned(context::AbstractContext, vn::VarName) = false +hasconditioned(context::ConditionContext, vn::VarName) = hasvalue(context.values, vn) +function hasconditioned(context::ConditionContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getconditioned(context::AbstractContext, vn::VarName) + +Return value of `vn` in `context`. +""" +function getconditioned(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end + +""" + hasconditioned_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasconditioned(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. +""" +function hasconditioned_nested(context::AbstractContext, vn) + return hasconditioned_nested(NodeTrait(hasconditioned_nested, context), context, vn) +end +hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) +function hasconditioned_nested(::IsParent, context, vn) + return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) +end +function hasconditioned_nested(context::PrefixContext, vn) + return hasconditioned_nested(collapse_prefix_stack(context), vn) +end + +""" + getconditioned_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getconditioned`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getconditioned_nested(context::AbstractContext, vn) + return getconditioned_nested(NodeTrait(getconditioned_nested, context), context, vn) +end +function getconditioned_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getconditioned_nested(context::PrefixContext, vn) + return getconditioned_nested(collapse_prefix_stack(context), vn) +end +function getconditioned_nested(::IsParent, context, vn) + return if hasconditioned(context, vn) + getconditioned(context, vn) + else + getconditioned_nested(childcontext(context), vn) + end +end + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +See also: [`condition`](@ref) +""" +decondition_context(::IsLeaf, context, args...) = context +function decondition_context(::IsParent, context, args...) + return setchildcontext(context, decondition_context(childcontext(context), args...)) +end +function decondition_context(context, args...) + return decondition_context(NodeTrait(context), context, args...) +end +function decondition_context(context::ConditionContext) + return decondition_context(childcontext(context)) +end +function decondition_context(context::ConditionContext, sym, syms...) + new_values = deepcopy(context.values) + for s in (sym, syms...) + new_values = BangBang.delete!!(new_values, s) + end + return if length(new_values) == 0 + # No more values left, can unwrap + decondition_context(childcontext(context), syms...) + else + ConditionContext( + new_values, decondition_context(childcontext(context), sym, syms...) + ) + end +end +function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} + return ConditionContext( + BangBang.delete!!(context.values, sym), + decondition_context(childcontext(context), vn), + ) +end + +""" + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under context`. + +Note that this will recursively traverse the context stack and return +a merged version of the condition values. +""" +function conditioned(context::AbstractContext) + return conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = NamedTuple() +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return _merge(context.values, conditioned(childcontext(context))) +end +function conditioned(context::PrefixContext) + return conditioned(collapse_prefix_stack(context)) +end + +struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +const NamedFixedContext{Names} = FixedContext{<:NamedTuple{Names}} +const DictFixedContext = FixedContext{<:AbstractDict} + +FixedContext(values) = FixedContext(values, DefaultContext()) + +# Try to avoid nested `FixedContext`. +function FixedContext(values::NamedTuple, context::NamedFixedContext) + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `FixedContext`. + return FixedContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::FixedContext) + return print(io, "FixedContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(::FixedContext) = IsParent() +childcontext(context::FixedContext) = context.context +setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) + +""" + hasfixed(context::AbstractContext, vn::VarName) + +Return `true` if a fixed value for `vn` is found in `context`. +""" +hasfixed(context::AbstractContext, vn::VarName) = false +hasfixed(context::FixedContext, vn::VarName) = hasvalue(context.values, vn) +function hasfixed(context::FixedContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(hasvalue, context.values), vns) +end + +""" + getfixed(context::AbstractContext, vn::VarName) + +Return the fixed value of `vn` in `context`. +""" +function getfixed(context::AbstractContext, vn::VarName) + return error("context $(context) does not contain value for $vn") +end +getfixed(context::FixedContext, vn::VarName) = getvalue(context.values, vn) + +""" + hasfixed_nested(context, vn) + +Return `true` if a fixed value for `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasfixed(::AbstractContext, ::VarName)`](@ref) which only checks +for `vn` in `context`, not recursively checking if `vn` is in any of its descendants. +""" +function hasfixed_nested(context::AbstractContext, vn) + return hasfixed_nested(NodeTrait(hasfixed_nested, context), context, vn) +end +hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) +function hasfixed_nested(::IsParent, context, vn) + return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) +end +function hasfixed_nested(context::PrefixContext, vn) + return hasfixed_nested(collapse_prefix_stack(context), vn) +end + +""" + getfixed_nested(context, vn) + +Return the fixed value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getfixed`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getfixed_nested(context::AbstractContext, vn) + return getfixed_nested(NodeTrait(getfixed_nested, context), context, vn) +end +function getfixed_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getfixed_nested(context::PrefixContext, vn) + return getfixed_nested(collapse_prefix_stack(context), vn) +end +function getfixed_nested(::IsParent, context, vn) + return if hasfixed(context, vn) + getfixed(context, vn) + else + getfixed_nested(childcontext(context), vn) + end +end + +""" + fix([context::AbstractContext,] values::NamedTuple) + fix([context::AbstractContext]; values...) + +Return `FixedContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`unfix`](@ref) +""" +fix(; values...) = fix(NamedTuple(values)) +fix(values::NamedTuple) = fix(DefaultContext(), values) +function fix(value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix((value, values...)) +end +function fix(values::NTuple{<:Any,<:Pair{<:VarName}}) + return fix(DefaultContext(), values) +end +fix(context::AbstractContext, values::NamedTuple{()}) = context +function fix(context::AbstractContext, values::Union{AbstractDict,NamedTuple}) + return FixedContext(values, context) +end +function fix(context::AbstractContext; values...) + return fix(context, NamedTuple(values)) +end +function fix(context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}...) + return fix(context, (value, values...)) +end +function fix(context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}}) + return fix(context, Dict(values)) +end + +""" + unfix(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer fixed. + +Note that this recursively traverses contexts, unfixing all along the way. + +See also: [`fix`](@ref) +""" +unfix(::IsLeaf, context, args...) = context +function unfix(::IsParent, context, args...) + return setchildcontext(context, unfix(childcontext(context), args...)) +end +function unfix(context, args...) + return unfix(NodeTrait(context), context, args...) +end +function unfix(context::FixedContext) + return unfix(childcontext(context)) +end +function unfix(context::FixedContext, sym) + return fix(unfix(childcontext(context), sym), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, sym, syms...) + return unfix( + fix(unfix(childcontext(context), syms...), BangBang.delete!!(context.values, sym)), + syms..., + ) +end + +function unfix(context::NamedFixedContext, vn::VarName{sym}) where {sym} + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, sym)) +end +function unfix(context::FixedContext, vn::VarName) + return fix(unfix(childcontext(context), vn), BangBang.delete!!(context.values, vn)) +end + +""" + fixed(context::AbstractContext) + +Return the values that are fixed under `context`. + +Note that this will recursively traverse the context stack and return +a merged version of the fix values. +""" +fixed(context::AbstractContext) = fixed(NodeTrait(fixed, context), context) +fixed(::IsLeaf, context) = NamedTuple() +fixed(::IsParent, context) = fixed(childcontext(context)) +function fixed(context::FixedContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `fixed` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return _merge(context.values, fixed(childcontext(context))) +end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +########################################################################### +### Interaction of PrefixContext with ConditionContext and FixedContext ### +########################################################################### + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/contexts/default.jl b/src/contexts/default.jl new file mode 100644 index 000000000..ec21e1a56 --- /dev/null +++ b/src/contexts/default.jl @@ -0,0 +1,60 @@ +""" + struct DefaultContext <: AbstractContext end + +`DefaultContext`, as the name suggests, is the default context used when instantiating a +model. + +```jldoctest +julia> @model f() = x ~ Normal(); + +julia> model = f(); model.context +DefaultContext() +``` + +As an evaluation context, the behaviour of `DefaultContext` is to require all variables to be +present in the `AbstractVarInfo` used for evaluation. Thus, semantically, evaluating a model +with `DefaultContext` means 'calculating the log-probability associated with the variables +in the `AbstractVarInfo`'. +""" +struct DefaultContext <: AbstractContext end +NodeTrait(::DefaultContext) = IsLeaf() + +""" + DynamicPPL.tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + ) + +Handle assumed variables. For `DefaultContext`, this function extracts the value associated +with `vn` from `vi`, If `vi` does not contain an appropriate value then this will error. +""" +function tilde_assume!!( + ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, right) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) + return x, vi +end + +""" + DynamicPPL.tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, + ) + +Handle observed variables. This just accumulates the log-likelihood for `left`. +""" +function tilde_observe!!( + ::DefaultContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi +end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl new file mode 100644 index 000000000..24615e683 --- /dev/null +++ b/src/contexts/prefix.jl @@ -0,0 +1,116 @@ +""" + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} + +Create a context that allows you to use the wrapped `context` when running the model and +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`to_submodel`](@ref) +""" +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn + context::C +end +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) +end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) + +NodeTrait(::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) +end + +""" + prefix(ctx::AbstractContext, vn::VarName) + +Apply the prefixes in the context `ctx` to the variable name `vn`. +""" +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) +end +function prefix(ctx::AbstractContext, vn::VarName) + return prefix(NodeTrait(ctx), ctx, vn) +end +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +function prefix(::IsParent, ctx::AbstractContext, vn::VarName) + return prefix(childcontext(ctx), vn) +end + +""" + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. + +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. + +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. +""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +function tilde_assume!!( + context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo +) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume!!(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume!!(new_context, right, new_vn, vi) +end + +function tilde_observe!!( + context::PrefixContext, + right::Distribution, + left, + vn::Union{VarName,Nothing}, + vi::AbstractVarInfo, +) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) +end diff --git a/src/transforming.jl b/src/contexts/transformation.jl similarity index 61% rename from src/transforming.jl rename to src/contexts/transformation.jl index 589dca031..720fa978f 100644 --- a/src/transforming.jl +++ b/src/contexts/transformation.jl @@ -43,31 +43,3 @@ function tilde_observe!!( ) return tilde_observe!!(DefaultContext(), right, left, vn, vi) end - -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(t, DynamicTransformationContext{false}(), vi, model) -end - -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) -end - -function _transform!!( - t::AbstractTransformation, - ctx::DynamicTransformationContext, - vi::AbstractVarInfo, - model::Model, -) - # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: - model = contextualize(model, setleafcontext(model.context, ctx)) - vi = settrans!!(last(evaluate!!(model, vi)), t) - return vi -end - -function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 2ec8b15a2..13124e3a7 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -1,7 +1,6 @@ module DebugUtils using ..DynamicPPL -using ..DynamicPPL: broadcast_safe, AbstractContext, childcontext using Random: Random using Accessors: Accessors diff --git a/src/submodel.jl b/src/submodel.jl index dcb107bb4..145bd42c9 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -8,6 +8,10 @@ struct Submodel{M,AutoPrefix} model::M end +# ---------------------- +# Constructing submodels +# ---------------------- + """ to_submodel(model::Model[, auto_prefix::Bool]) @@ -152,6 +156,26 @@ ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observ """ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) +# --------------------------- +# Submodels in tilde-pipeline +# --------------------------- + +""" + DynamicPPL.tilde_assume!!( + context::AbstractContext, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo + ) + +Evaluate the submodel with the given context. +""" +function tilde_assume!!( + context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo +) + return _evaluate!!(right, vi, context, vn) +end + # When automatic prefixing is used, the submodel itself doesn't carry the # prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel # is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then @@ -193,3 +217,13 @@ function _evaluate!!( # returns a tuple of submodel.model's return value and the new varinfo. return _evaluate!!(model, vi) end + +function tilde_observe!!( + ::AbstractContext, + ::DynamicPPL.Submodel, + left, + vn::Union{VarName,Nothing}, + ::AbstractVarInfo, +) + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) +end diff --git a/src/utils.jl b/src/utils.jl index a4c5f4a1b..b09bfb9fa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -793,10 +793,6 @@ end # Handle `AbstractDict` differently since `eltype` results in a `Pair`. infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET) -broadcast_safe(x) = x -broadcast_safe(x::Distribution) = (x,) -broadcast_safe(x::AbstractContext) = (x,) - # Convert (x=1,) to Dict(@varname(x) => 1) function to_varname_dict(nt::NamedTuple) return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) From c08cfa5e047946086fda67e1f4a36ae0342d261a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 29 Sep 2025 16:24:44 +0200 Subject: [PATCH 11/14] Delete the `"del"` flag (#1058) * Delete del * Fix a typo * Add HISTORY entry about del --- HISTORY.md | 4 ++++ src/threadsafe.jl | 6 ++---- src/varinfo.jl | 42 +++++++++--------------------------------- test/varinfo.jl | 17 ++++++----------- 4 files changed, 21 insertions(+), 48 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index d67afcbfe..aaa5ac1eb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -50,6 +50,10 @@ Other functions such as `tilde_assume` and `assume` (and their `observe` counter Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other). The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details). +### Removal of the `"del"` flag + +Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. + **Other changes** ### Reimplementation of functions using `InitContext` diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 6ca3b9852..f89a562e3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -185,10 +185,8 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function unset_flag!( - vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false -) - return unset_flag!(vi.varinfo, vn, flag, ignoreable) +function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) + return unset_flag!(vi.varinfo, vn, flag) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/varinfo.jl b/src/varinfo.jl index 081f65ea1..062cc236b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -423,7 +423,6 @@ Construct an empty type unstable instance of `Metadata`. function Metadata() vals = Vector{Real}() flags = Dict{String,BitVector}() - flags["del"] = BitVector() flags["trans"] = BitVector() return Metadata( @@ -887,12 +886,7 @@ function set_flag!(md::Metadata, vn::VarName, flag::String) end function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) - if flag == "del" - # The "del" flag is effectively always set for a VarNamedVector, so this is a no-op. - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end - return vnv + throw(ErrorException("VarNamedVector does not support flags; Tried to set $(flag).")) end #### @@ -1710,7 +1704,7 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) [1:length(val)], val, [dist], - Dict{String,BitVector}("trans" => [false], "del" => [false]), + Dict{String,BitVector}("trans" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else @@ -1744,7 +1738,6 @@ function Base.push!(meta::Metadata, vn, r, dist) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.flags["del"], false) push!(meta.flags["trans"], false) return meta end @@ -1770,42 +1763,25 @@ function is_flagged(metadata::Metadata, vn::VarName, flag::String) return metadata.flags[flag][getidx(metadata, vn)] end function is_flagged(::VarNamedVector, ::VarName, flag::String) - if flag == "del" - return true - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end + throw(ErrorException("VarNamedVector does not support flags; Tried to read $(flag).")) end -# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector, -# but still having to support the interface based on Metadata too """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false + unset_flag!(vi::VarInfo, vn::VarName, flag::String Set `vn`'s value for `flag` to `false` in `vi`. - -Setting some flags for some `VarInfo` types is not possible, and by default attempting to do -so will error. If `ignorable` is set to `true` then this will silently be ignored instead. """ -function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) - unset_flag!(getmetadata(vi, vn), vn, flag, ignorable) +function unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!(getmetadata(vi, vn), vn, flag) return vi end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false) +function unset_flag!(metadata::Metadata, vn::VarName, flag::String) metadata.flags[flag][getidx(metadata, vn)] = false return metadata end -function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false) - if ignorable - return vnv - end - if flag == "del" - throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector")) - else - throw(ErrorException("Flag $flag not valid for VarNamedVector")) - end - return vnv +function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String) + throw(ErrorException("VarNamedVector does not support flags; Tried to unset $(flag).")) end # TODO: Maybe rename or something? diff --git a/test/varinfo.jl b/test/varinfo.jl index 75d8e062b..dc09ff8da 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -22,11 +22,6 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) r = rand(dist) push!!(vi, vn, r, dist) r - elseif DynamicPPL.is_flagged(vi, vn, "del") - DynamicPPL.unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - r else vi[vn] end @@ -300,14 +295,14 @@ end push!!(vi, vn_x, r, dist) - # del is set by default - @test !is_flagged(vi, vn_x, "del") + # trans is set by default + @test !is_flagged(vi, vn_x, "trans") - set_flag!(vi, vn_x, "del") - @test is_flagged(vi, vn_x, "del") + set_flag!(vi, vn_x, "trans") + @test is_flagged(vi, vn_x, "trans") - unset_flag!(vi, vn_x, "del") - @test !is_flagged(vi, vn_x, "del") + unset_flag!(vi, vn_x, "trans") + @test !is_flagged(vi, vn_x, "trans") end vi = VarInfo() test_varinfo!(vi) From 08212a21deb84d3308f31c66e65d92dd17bc7b19 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 30 Sep 2025 15:15:59 +0100 Subject: [PATCH 12/14] Fixes for Turing 0.41 (#1057) * setleafcontext(model, ctx) and various other fixes * fix a bug * Add warning for `initial_parameters=...` --- HISTORY.md | 4 ++++ ext/DynamicPPLJETExt.jl | 8 ++------ src/abstract_varinfo.jl | 6 ++---- src/contexts.jl | 23 +++++++++++++---------- src/model.jl | 13 +++++++++++-- src/sampler.jl | 25 +++++++++++++++++++++---- src/test_utils/contexts.jl | 2 +- src/threadsafe.jl | 8 ++------ test/ext/DynamicPPLJETExt.jl | 10 ++-------- 9 files changed, 58 insertions(+), 41 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 3f20cfd2f..f69c4a6fd 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -56,6 +56,10 @@ Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo **Other changes** +### `setleafcontext(model, context)` + +This convenience method has been added to quickly modify the leaf context of a model. + ### Reimplementation of functions using `InitContext` A number of functions have been reimplemented and unified with the help of `InitContext`. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index 55016d40c..e0163bb35 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -24,9 +24,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( varinfo = DynamicPPL.typed_varinfo(model) # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) @@ -36,9 +34,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( end # Check type stability of initialisation (i.e. InitContext) - model = DynamicPPL.contextualize( - model, DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()) - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( model, varinfo; only_ddpl ) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index b3cf77121..7cc800dbb 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -830,8 +830,7 @@ end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) # Note that in practice this method is only called for SimpleVarInfo, because VarInfo # has a dedicated implementation - ctx = DynamicTransformationContext{false}() - model = contextualize(model, setleafcontext(model.context, ctx)) + model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return settrans!!(vi, t) end @@ -893,8 +892,7 @@ end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) # Note that in practice this method is only called for SimpleVarInfo, because VarInfo # has a dedicated implementation - ctx = DynamicTransformationContext{true}() - model = contextualize(model, setleafcontext(model.context, ctx)) + model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return settrans!!(vi, NoTransformation()) end diff --git a/src/contexts.jl b/src/contexts.jl index 70f99a73f..32a236e8e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -58,16 +58,17 @@ DynamicTransformationContext{true}() setchildcontext """ - leafcontext(context) + leafcontext(context::AbstractContext) Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. """ -leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) -leafcontext(::IsLeaf, context) = context -leafcontext(::IsParent, context) = leafcontext(childcontext(context)) +leafcontext(context::AbstractContext) = + leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context::AbstractContext) = context +leafcontext(::IsParent, context::AbstractContext) = leafcontext(childcontext(context)) """ - setleafcontext(left, right) + setleafcontext(left::AbstractContext, right::AbstractContext) Return `left` but now with its leaf context replaced by `right`. @@ -103,19 +104,21 @@ julia> # Append another parent context. ParentContext(ParentContext(ParentContext(DefaultContext()))) ``` """ -function setleafcontext(left, right) +function setleafcontext(left::AbstractContext, right::AbstractContext) return setleafcontext( NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) end -function setleafcontext(::IsParent, ::IsParent, left, right) +function setleafcontext( + ::IsParent, ::IsParent, left::AbstractContext, right::AbstractContext +) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -function setleafcontext(::IsParent, ::IsLeaf, left, right) +function setleafcontext(::IsParent, ::IsLeaf, left::AbstractContext, right::AbstractContext) return setchildcontext(left, setleafcontext(childcontext(left), right)) end -setleafcontext(::IsLeaf, ::IsParent, left, right) = right -setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right +setleafcontext(::IsLeaf, ::IsParent, left::AbstractContext, right::AbstractContext) = right +setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext) = right """ DynamicPPL.tilde_assume!!( diff --git a/src/model.jl b/src/model.jl index a6a3e0685..6c7e8de94 100644 --- a/src/model.jl +++ b/src/model.jl @@ -95,6 +95,16 @@ function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end +""" + setleafcontext(model::Model, context::AbstractContext) + +Return a new `Model` with its leaf context set to `context`. This is a convenience shortcut +for `contextualize(model, setleafcontext(model.context, context)`). +""" +function setleafcontext(model::Model, context::AbstractContext) + return contextualize(model, setleafcontext(model.context, context)) +end + """ model | (x = 1.0, ...) @@ -886,8 +896,7 @@ function init!!( varinfo::AbstractVarInfo, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) - new_model = contextualize(model, new_context) + new_model = setleafcontext(model, InitContext(rng, init_strategy)) return evaluate!!(new_model, varinfo) end function init!!( diff --git a/src/sampler.jl b/src/sampler.jl index 8b49f6c3b..c598e13f5 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -46,12 +46,12 @@ function default_varinfo(rng::Random.AbstractRNG, model::Model, ::AbstractSample end """ - init_strategy(sampler) + init_strategy(sampler::AbstractSampler) Define the initialisation strategy used for generating initial values when sampling with `sampler`. Defaults to `InitFromPrior()`, but can be overridden. """ -init_strategy(::Sampler) = InitFromPrior() +init_strategy(::AbstractSampler) = InitFromPrior() function AbstractMCMC.sample( rng::Random.AbstractRNG, @@ -60,11 +60,15 @@ function AbstractMCMC.sample( N::Integer; chain_type=default_chain_type(sampler), resume_from=nothing, + initial_params=init_strategy(sampler), initial_state=loadstate(resume_from), kwargs..., ) + if hasproperty(kwargs, :initial_parameters) + @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." + end return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_state, kwargs... + rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... ) end @@ -76,12 +80,25 @@ function AbstractMCMC.sample( N::Integer, nchains::Integer; chain_type=default_chain_type(sampler), + initial_params=fill(init_strategy(sampler), nchains), resume_from=nothing, initial_state=loadstate(resume_from), kwargs..., ) + if hasproperty(kwargs, :initial_parameters) + @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." + end return AbstractMCMC.mcmcsample( - rng, model, sampler, parallel, N, nchains; chain_type, initial_state, kwargs... + rng, + model, + sampler, + parallel, + N, + nchains; + chain_type, + initial_params, + initial_state, + kwargs..., ) end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index d53ba6c5f..aae2e4ec6 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -47,7 +47,7 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) typed_vi = DynamicPPL.typed_varinfo(untyped_vi) # Set the test context as the new leaf context - new_model = contextualize(model, DynamicPPL.setleafcontext(model.context, context)) + new_model = DynamicPPL.setleafcontext(model, context) # Check that evaluation works for vi in [untyped_vi, typed_vi] _, vi = DynamicPPL.evaluate!!(new_model, vi) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f89a562e3..e86a4c4ae 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -103,16 +103,12 @@ end # consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates # to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) + model = setleafcontext(model, DynamicTransformationContext{false}()) return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) + model = setleafcontext(model, DynamicTransformationContext{true}()) return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index b34424a1c..8ed29e0c7 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -81,20 +81,14 @@ typed_vi = DynamicPPL.typed_varinfo(model) @info "Evaluating with DefaultContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.DefaultContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) JET.test_call(f, argtypes) @info "Initialising with InitContext:" - model = DynamicPPL.contextualize( - model, - DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext()), - ) + model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( model, varinfo ) From 7abd5fbf0c0a2ee59f244da1e2e0a8836eda45a1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 2 Oct 2025 16:10:28 +0100 Subject: [PATCH 13/14] Remove `resume_from` and `default_chain_type` (#1061) * Remove resume_from * Format * Fix test --- HISTORY.md | 5 +++++ src/DynamicPPL.jl | 2 ++ src/sampler.jl | 37 ++++++++----------------------------- test/sampler.jl | 30 +++--------------------------- 4 files changed, 18 insertions(+), 56 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f69c4a6fd..29bc56493 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -54,6 +54,11 @@ The separation of these functions was primarily implemented to avoid performing Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. +### Removal of `resume_from` + +The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. +`loadstate` is exported from DynamicPPL. + **Other changes** ### `setleafcontext(model, context)` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 31adadb55..43180b091 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,6 +130,8 @@ export AbstractVarInfo, prefix, returned, to_submodel, + # Chain save/resume + loadstate, # Convenience macros @addlogprob!, value_iterator_from_chain, diff --git a/src/sampler.jl b/src/sampler.jl index c598e13f5..01f056053 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,17 +58,15 @@ function AbstractMCMC.sample( model::Model, sampler::Sampler, N::Integer; - chain_type=default_chain_type(sampler), - resume_from=nothing, initial_params=init_strategy(sampler), - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, model, sampler, N; chain_type, initial_params, initial_state, kwargs... + rng, model, sampler, N; initial_params, initial_state, kwargs... ) end @@ -79,26 +77,15 @@ function AbstractMCMC.sample( parallel::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, nchains::Integer; - chain_type=default_chain_type(sampler), initial_params=fill(init_strategy(sampler), nchains), - resume_from=nothing, - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) if hasproperty(kwargs, :initial_parameters) @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." end return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - parallel, - N, - nchains; - chain_type, - initial_params, - initial_state, - kwargs..., + rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs... ) end @@ -124,20 +111,12 @@ function AbstractMCMC.step( end """ - loadstate(data) + loadstate(chain::AbstractChains) -Load sampler state from `data`. - -By default, `data` is returned. -""" -loadstate(data) = data - -""" - default_chain_type(sampler) - -Default type of the chain of posterior samples from `sampler`. +Load sampler state from an `AbstractChains` object. This function should be overloaded by a +concrete Chains implementation. """ -default_chain_type(::Sampler) = Any +function loadstate end """ initialstep(rng, model, sampler, varinfo; kwargs...) diff --git a/test/sampler.jl b/test/sampler.jl index 5380ad17e..8be54901d 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -12,7 +12,7 @@ @test AbstractMCMC.step(Xoshiro(468), g(), spl) isa Any end - @testset "initial_state and resume_from kwargs" begin + @testset "initial_state" begin # Model is unused, but has to be a DynamicPPL.Model otherwise we won't hit our # overloaded method. @model f() = x ~ Normal() @@ -52,26 +52,15 @@ chn = sample(model, spl, N_iters; progress=false, chain_type=MCMCChains.Chains) initial_value = chn[:x][1] @test all(chn[:x] .== initial_value) # sanity check - # using `initial_state` chn2 = sample( model, spl, N_iters; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(chn2[:x] .== initial_value) - # using `resume_from` - chn3 = sample( - model, - spl, - N_iters; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(chn3[:x] .== initial_value) end @testset "multiple-chain sampling" begin @@ -86,7 +75,6 @@ ) initial_value = chn[:x][1, :] @test all(i -> chn[:x][i, :] == initial_value, 1:N_iters) # sanity check - # using `initial_state` chn2 = sample( model, spl, @@ -94,22 +82,10 @@ N_iters, N_chains; progress=false, - initial_state=chn.info.samplerstate, + initial_state=DynamicPPL.loadstate(chn), chain_type=MCMCChains.Chains, ) @test all(i -> chn2[:x][i, :] == initial_value, 1:N_iters) - # using `resume_from` - chn3 = sample( - model, - spl, - MCMCThreads(), - N_iters, - N_chains; - progress=false, - resume_from=chn, - chain_type=MCMCChains.Chains, - ) - @test all(i -> chn3[:x][i, :] == initial_value, 1:N_iters) end end From 908d4025005c7aaa95af31c074dd6c8fe9b73a1e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 3 Oct 2025 07:45:50 +0100 Subject: [PATCH 14/14] remove initial_params warning --- src/sampler.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 01f056053..ed1b86321 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -62,9 +62,6 @@ function AbstractMCMC.sample( initial_state=nothing, kwargs..., ) - if hasproperty(kwargs, :initial_parameters) - @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." - end return AbstractMCMC.mcmcsample( rng, model, sampler, N; initial_params, initial_state, kwargs... ) @@ -81,9 +78,6 @@ function AbstractMCMC.sample( initial_state=nothing, kwargs..., ) - if hasproperty(kwargs, :initial_parameters) - @warn "The `initial_parameters` keyword argument is not recognised; please use `initial_params` instead." - end return AbstractMCMC.mcmcsample( rng, model, sampler, parallel, N, nchains; initial_params, initial_state, kwargs... )