Skip to content

InitContext, part 3 - Introduce InitContext #981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: breaking
Choose a base branch
from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Jul 9, 2025

Part 1: Adding hasvalue and getvalue to AbstractPPL
Part 2: Removing hasvalue and getvalue from DynamicPPL

This is Part 3/N of splitting up #967.


This PR solely adds the code for a new leaf context, InitContext. Its behaviour is to always override existing values inside a VarInfo.

A long-held goal of mine is to split up contexts.jl and context_implementations.jl such that each context is in its own file. To this end, I've put all the InitContext-related code inside src/contexts/init.jl. This should, hopefully, also make reviewing easier.

InitContext comes in three forms:

  1. PriorInit: Samples values from the prior;
  2. UniformInit(min, max): Samples values from between min and max before invlinking them back into the support of the original distribution;
  3. ParamsInit(params): Takes values from params, which can be a NamedTuple or a Dict. Also takes a fallback strategy, which is used in case the varname is not found in params.

PriorInit and UniformInit have almost one-to-one correspondence with SampleFromPrior and SampleFromUniform. I haven't removed them yet because a LOT of stuff depends on SampleFromPrior. We will eventually remove all of it but doing it in a single PR would make it way too big.

However, ParamsInit is new. Its purpose is to set fixed values in a VarInfo in a principled manner. Right now, there is a fair amount of code that does low-level VarInfo manipulation by reaching inside and modifying the underlying arrays, such as initialise_values!!, and setval_and_resample!. This is prone to bugs, and also leads to (for example) the VarInfo's logp being out of sync. (All of that code will be removed in subsequent PRs.)

ParamsInit accomplishes the same in a much cleaner way, although in some cases (specifically, when setting initial values for sampling) it may necessitate one extra model evaluation. Because initialisation is not meant to happen within a tight loop (by definition it happens once), I don't consider performance to be an important issue.

One might ask what's the difference between InitContext{ParamsInit}, ConditionContext, and FixedContext. Consider the model @model f() = x ~ Normal(). Evaluating this with each of the contexts gives the following behaviour:

InitContext{ParamsInit} ConditionContext FixedContext
logpdf of x ~ Normal() Prior Likelihood Ignored
x in VarInfo Yes No No

In other words, InitContext behaves as if you had run the model and sampled a new value for x, but that value just so happened to be the one that you provided. In contrast, Condition and FixedContext fundamentally change the model that is being run in that x is no longer a parameter.

Closes #375.

Copy link
Contributor

github-actions bot commented Jul 9, 2025

Benchmark Report for Commit bcfdd93

Computer Information

Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  8.7 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                654.1 |                44.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                423.3 |                52.5 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1184.4 |                29.9 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               6597.3 |                29.2 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1463.8 |                28.7 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |               1002.2 |                 4.2 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5778.1 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |                981.6 |                 9.3 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              65268.8 |                 3.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8796.1 |                 9.8 |
|               Dynamic |        10 |    mooncake |             typed |   true |                133.3 |                12.3 |
|              Submodel |         1 |    mooncake |             typed |   true |                 13.2 |                 4.7 |
|                   LDA |        12 | reversediff |             typed |   true |               1179.1 |                 3.9 |

Copy link
Contributor

DynamicPPL.jl documentation for PR #981 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR981/

Copy link

codecov bot commented Jul 10, 2025

Codecov Report

❌ Patch coverage is 0% with 63 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (breaking@5a9e9d2). Learn more about missing BASE report.

Files with missing lines Patch % Lines
src/contexts/init.jl 0.00% 49 Missing ⚠️
src/model.jl 0.00% 11 Missing ⚠️
src/varnamedvector.jl 0.00% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             breaking     #981   +/-   ##
===========================================
  Coverage            ?   33.70%           
===========================================
  Files               ?       39           
  Lines               ?     3979           
  Branches            ?        0           
===========================================
  Hits                ?     1341           
  Misses              ?     2638           
  Partials            ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines -283 to -287
"""
prefix(model::Model, x::VarName)
prefix(model::Model, x::Val{sym})
prefix(model::Model, x::Any)

Copy link
Member Author

@penelopeysm penelopeysm Jul 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was shifted verbatim to src/model.jl to avoid circular dependencies between files.

@penelopeysm penelopeysm changed the title Introduce InitContext InitContext, part 3 - Introduce InitContext Jul 10, 2025
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 025aa8b to b55c1e1 Compare July 10, 2025 14:24
Base automatically changed from py/hasgetvalue to breaking July 25, 2025 14:22
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from 5f60c46 to 4408efb Compare July 26, 2025 17:45
@penelopeysm penelopeysm marked this pull request as ready for review July 26, 2025 17:46
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from 4408efb to 9c07727 Compare July 26, 2025 18:25
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from 9c07727 to ef038c6 Compare August 5, 2025 11:52
Base automatically changed from breaking to main August 7, 2025 09:45
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from ef038c6 to 5961ca9 Compare August 8, 2025 10:15
@penelopeysm penelopeysm mentioned this pull request Aug 8, 2025
8 tasks
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from 5961ca9 to 0656487 Compare August 8, 2025 10:20
@penelopeysm penelopeysm changed the base branch from main to breaking August 8, 2025 10:20
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch 2 times, most recently from fd78d42 to bfcdbb9 Compare August 10, 2025 13:33
Note that, apart from being simpler code, Distributions.Uniform also
doesn't allow the lower and upper bounds to be exactly equal (but we
might like to keep that option open in DynamicPPL, e.g. if the user
wants to initialise all values to the same value in linked space).
This should have been changed in #940, but slipped through as the file
wasn't listed as one of the changed files.
@penelopeysm penelopeysm force-pushed the py/init-prior-uniform branch from bfcdbb9 to ab3e8da Compare August 10, 2025 13:44
@penelopeysm penelopeysm requested a review from mhauru August 12, 2025 15:05
!!! warning "Values must be unlinked"
The values returned by `init` are always in the untransformed space, i.e.,
they must be within the support of the original distribution. That means that,
for example, `init(rng, dist, u::UniformInit)` will in general return values that
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about calling it UniformLinkedInit to avoid a confusion here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmmm I don't have a super strong preference (the current docstring clarifies it, and the lack of a docstring was IMO the main problem with SampleFromUniform), but let's go with the more explicit and clearer name.

Copy link
Member Author

@penelopeysm penelopeysm Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I also don't want to give the impression that init(..., ::UniformLinkedInit) returns values in linked space. :/ On this basis I might prefer the original UniformInit a bit better.

AbstractInitStrategy

Abstract type representing the possible ways of initialising new values for
the random variables in a model (e.g., when creating a new VarInfo).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this have a list of functions subtypes must implement methods for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, done.

return if hasvalue(p.params, vn, dist)
x = getvalue(p.params, vn, dist)
if x === missing
init(rng, vn, dist, p.default)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could there be demand for p.default=nothing in which case this would error? I wonder if in some cases the current implementation could cause silent unexpected behaviour, e.g. if you misspell the varname and thus end up getting samples from the prior rather than the expected value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. I can't remember what prompted it now, but I did think about this a while back. I don't think that this would get used anywhere in DynamicPPL, but I'm not opposed to adding it in anyway.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, do you prefer p.default, or p.fallback?

@penelopeysm penelopeysm requested a review from mhauru August 12, 2025 17:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants