Skip to content

Commit e5a038e

Browse files
committed
Allow opting out of TSVI
1 parent 62a8746 commit e5a038e

File tree

5 files changed

+60
-18
lines changed

5 files changed

+60
-18
lines changed

docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the v
110110
DynamicPPL.unfix
111111
```
112112

113+
## Controlling threadsafe evaluation
114+
115+
```@docs
116+
DynamicPPL.set_threadsafe_eval!
117+
```
118+
113119
## Predicting
114120

115121
DynamicPPL provides functionality for generating samples from the posterior predictive distribution through the `predict` function. This allows you to use posterior parameter samples to generate predictions for unobserved data points.

src/DynamicPPL.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ export AbstractVarInfo,
9292
getargnames,
9393
extract_priors,
9494
values_as_in_model,
95+
set_threadsafe_eval!,
9596
# LogDensityFunction
9697
LogDensityFunction,
9798
# Leaf contexts
@@ -212,8 +213,12 @@ include("test_utils.jl")
212213
include("experimental.jl")
213214
include("deprecated.jl")
214215

215-
if isdefined(Base.Experimental, :register_error_hint)
216-
function __init__()
216+
function __init__()
217+
# This has to be in the `__init__()` function, if it's placed at the top level it
218+
# always evaluates to false.
219+
DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)
220+
221+
if isdefined(Base.Experimental, :register_error_hint)
217222
# Better error message if users forget to load JET.jl
218223
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
219224
requires_jet =

src/fasteval.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,7 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
219219
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
220220
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
221221
# here.
222-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
223-
# it _should_ do, but this is wrong regardless.
224-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
225-
vi = if Threads.nthreads() > 1
222+
vi = if DynamicPPL.USE_THREADSAFE_EVAL[]
226223
accs = map(
227224
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
228225
accs,

src/model.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
1+
# This is overridden in the `__init__()` function (src/DynamicPPL.jl)
2+
USE_THREADSAFE_EVAL = Ref(true)
3+
4+
"""
5+
DynamicPPL.set_threadsafe_eval!(val::Bool)
6+
7+
Enable or disable threadsafe model evaluation globally. By default, threadsafe evaluation is
8+
used whenever Julia is run with multiple threads.
9+
10+
However, this is not necessary for the vast majority of DynamicPPL models. **In particular,
11+
use of threaded sampling with MCMCChains alone does NOT require threadsafe evaluation.**
12+
Threadsafe evaluation is only needed when manipulating `VarInfo` objects in parallel, e.g.
13+
when using `x ~ dist` statements inside `Threads.@threads` blocks.
14+
15+
If you do not need threadsafe evaluation, disabling it can lead to significant performance
16+
improvements.
17+
"""
18+
function set_threadsafe_eval!(val::Bool)
19+
USE_THREADSAFE_EVAL[] = val
20+
return nothing
21+
end
22+
123
"""
224
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
325
f::F
@@ -863,16 +885,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
863885
return first(init!!(rng, model, varinfo))
864886
end
865887

866-
"""
867-
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
868-
869-
Return `true` if evaluation of a model using `context` and `varinfo` should
870-
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
871-
"""
872-
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
873-
return Threads.nthreads() > 1
874-
end
875-
876888
"""
877889
init!!(
878890
[rng::Random.AbstractRNG,]
@@ -912,14 +924,14 @@ end
912924
913925
Evaluate the `model` with the given `varinfo`.
914926
915-
If multiple threads are available, the varinfo provided will be wrapped in a
927+
If threadsafe evaluation is enabled, the varinfo provided will be wrapped in a
916928
`ThreadSafeVarInfo` before evaluation.
917929
918930
Returns a tuple of the model's return value, plus the updated `varinfo`
919931
(unwrapped if necessary).
920932
"""
921933
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
922-
return if use_threadsafe_eval(model.context, varinfo)
934+
return if DynamicPPL.USE_THREADSAFE_EVAL[]
923935
evaluate_threadsafe!!(model, varinfo)
924936
else
925937
evaluate_threadunsafe!!(model, varinfo)

test/threadsafe.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
11
@testset "threadsafe.jl" begin
2+
@testset "set threadsafe eval" begin
3+
# A dummy model that lets us see what type of VarInfo is being used for evaluation.
4+
@model function find_out_varinfo_type()
5+
x ~ Normal()
6+
return typeof(__varinfo__)
7+
end
8+
model = find_out_varinfo_type()
9+
10+
# Check the default.
11+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == (Threads.nthreads() > 1)
12+
# Disable it.
13+
DynamicPPL.set_threadsafe_eval!(false)
14+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == false
15+
@test !(model() <: DynamicPPL.ThreadSafeVarInfo)
16+
# Enable it.
17+
DynamicPPL.set_threadsafe_eval!(true)
18+
@test DynamicPPL.USE_THREADSAFE_EVAL[] == true
19+
@test model() <: DynamicPPL.ThreadSafeVarInfo
20+
# Reset to default to avoid messing with other tests.
21+
DynamicPPL.set_threadsafe_eval!(Threads.nthreads() > 1)
22+
end
23+
224
@testset "constructor" begin
325
vi = VarInfo(gdemo_default)
426
threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi)

0 commit comments

Comments
 (0)