1
1
# assume
2
- """
3
- tilde_assume(context::SamplingContext, right, vn, vi)
4
-
5
- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
6
- accumulate the log probability, and return the sampled value with a context associated
7
- with a sampler.
8
-
9
- Falls back to
10
- ```julia
11
- tilde_assume(context.rng, context.context, context.sampler, right, vn, vi)
12
- ```
13
- """
14
- function tilde_assume (context:: SamplingContext , right, vn, vi)
15
- return tilde_assume (context. rng, context. context, context. sampler, right, vn, vi)
16
- end
17
-
18
2
function tilde_assume (context:: AbstractContext , args... )
19
3
return tilde_assume (childcontext (context), args... )
20
4
end
21
5
function tilde_assume (:: DefaultContext , right, vn, vi)
22
- return assume (right, vn, vi)
23
- end
24
-
25
- function tilde_assume (rng:: Random.AbstractRNG , context:: AbstractContext , args... )
26
- return tilde_assume (rng, childcontext (context), args... )
27
- end
28
- function tilde_assume (rng:: Random.AbstractRNG , :: DefaultContext , sampler, right, vn, vi)
29
- return assume (rng, sampler, right, vn, vi)
30
- end
31
- function tilde_assume (:: Random.AbstractRNG , :: InitContext , sampler, right, vn, vi)
32
- return error (
33
- " Encountered SamplingContext->InitContext. This method will be removed in the next PR." ,
34
- )
35
- end
36
- function tilde_assume (:: DefaultContext , sampler, right, vn, vi)
37
- # same as above but no rng
38
- return assume (Random. default_rng (), sampler, right, vn, vi)
6
+ y = getindex_internal (vi, vn)
7
+ f = from_maybe_linked_internal_transform (vi, vn, right)
8
+ x, logjac = with_logabsdet_jacobian (f, y)
9
+ vi = accumulate_assume!! (vi, x, logjac, vn, right)
10
+ return x, vi
39
11
end
40
-
41
12
function tilde_assume (context:: PrefixContext , right, vn, vi)
42
13
# Note that we can't use something like this here:
43
14
# new_vn = prefix(context, vn)
@@ -51,12 +22,6 @@ function tilde_assume(context::PrefixContext, right, vn, vi)
51
22
new_vn, new_context = prefix_and_strip_contexts (context, vn)
52
23
return tilde_assume (new_context, right, new_vn, vi)
53
24
end
54
- function tilde_assume (
55
- rng:: Random.AbstractRNG , context:: PrefixContext , sampler, right, vn, vi
56
- )
57
- new_vn, new_context = prefix_and_strip_contexts (context, vn)
58
- return tilde_assume (rng, new_context, sampler, right, new_vn, vi)
59
- end
60
25
61
26
"""
62
27
tilde_assume!!(context, right, vn, vi)
@@ -76,17 +41,6 @@ function tilde_assume!!(context, right, vn, vi)
76
41
end
77
42
78
43
# observe
79
- """
80
- tilde_observe!!(context::SamplingContext, right, left, vi)
81
-
82
- Handle observed constants with a `context` associated with a sampler.
83
-
84
- Falls back to `tilde_observe!!(context.context, right, left, vi)`.
85
- """
86
- function tilde_observe!! (context:: SamplingContext , right, left, vn, vi)
87
- return tilde_observe!! (context. context, right, left, vn, vi)
88
- end
89
-
90
44
function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
91
45
return tilde_observe!! (childcontext (context), right, left, vn, vi)
92
46
end
@@ -119,59 +73,3 @@ function tilde_observe!!(::DefaultContext, right, left, vn, vi)
119
73
vi = accumulate_observe!! (vi, right, left, vn)
120
74
return left, vi
121
75
end
122
-
123
- function assume (:: Random.AbstractRNG , spl:: Sampler , dist)
124
- return error (" DynamicPPL.assume: unmanaged inference algorithm: $(typeof (spl)) " )
125
- end
126
-
127
- # fallback without sampler
128
- function assume (dist:: Distribution , vn:: VarName , vi)
129
- y = getindex_internal (vi, vn)
130
- f = from_maybe_linked_internal_transform (vi, vn, dist)
131
- x, logjac = with_logabsdet_jacobian (f, y)
132
- vi = accumulate_assume!! (vi, x, logjac, vn, dist)
133
- return x, vi
134
- end
135
-
136
- # TODO : Remove this thing.
137
- # SampleFromPrior and SampleFromUniform
138
- function assume (
139
- rng:: Random.AbstractRNG ,
140
- sampler:: Union{SampleFromPrior,SampleFromUniform} ,
141
- dist:: Distribution ,
142
- vn:: VarName ,
143
- vi:: VarInfoOrThreadSafeVarInfo ,
144
- )
145
- if haskey (vi, vn)
146
- # Always overwrite the parameters with new ones for `SampleFromUniform`.
147
- if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
148
- # TODO (mhauru) Is it important to unset the flag here? The `true` allows us
149
- # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure
150
- # if that's okay.
151
- unset_flag! (vi, vn, " del" , true )
152
- r = init (rng, dist, sampler)
153
- f = to_maybe_linked_internal_transform (vi, vn, dist)
154
- # TODO (mhauru) This should probably be call a function called setindex_internal!
155
- vi = BangBang. setindex!! (vi, f (r), vn)
156
- setorder! (vi, vn, get_num_produce (vi))
157
- else
158
- # Otherwise we just extract it.
159
- r = vi[vn, dist]
160
- end
161
- else
162
- r = init (rng, dist, sampler)
163
- if istrans (vi)
164
- f = to_linked_internal_transform (vi, vn, dist)
165
- vi = push!! (vi, vn, f (r), dist)
166
- # By default `push!!` sets the transformed flag to `false`.
167
- vi = settrans!! (vi, true , vn)
168
- else
169
- vi = push!! (vi, vn, r, dist)
170
- end
171
- end
172
-
173
- # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
174
- logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
175
- vi = accumulate_assume!! (vi, r, - logjac, vn, dist)
176
- return r, vi
177
- end
0 commit comments