From 0e43a0806ad42e8fdd090c288d41e2ade2bff68c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 2 Oct 2025 13:55:40 +0100 Subject: [PATCH 1/4] pass initial_state through for NUTS sampling --- src/mcmc/hmc.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 6ff975d4d..d776a68a8 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -120,6 +120,7 @@ function AbstractMCMC.sample( sampler, N; chain_type=chain_type, + initial_state=initial_state, progress=progress, nadapts=_nadapts, discard_initial=_discard_initial, From 08abb079b49dbcaf0af59f9f50d943fe42c3805b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 2 Oct 2025 14:11:12 +0100 Subject: [PATCH 2/4] Add a test --- test/mcmc/hmc.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 428c193ca..893bab113 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -197,6 +197,16 @@ using Turing @test_throws ErrorException sample(demo_impossible(), NUTS(), 5) end + @testset "NUTS initial parameters" begin + @model function f() + x ~ Normal() + return 10 ~ Normal(x) + end + chn1 = sample(StableRNG(468), f(), NUTS(), 100; save_state=true) + chn2 = sample(StableRNG(468), f(), NUTS(), 10; initial_state=chn1.info.samplerstate) + @test isapprox(chn1[:x], 5.0; atol=2.0) + end + @testset "(partially) issue: #2095" begin @model function vector_of_dirichlet((::Type{TV})=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2) From 58fdc729ffa0a2f8b46ec89e3236494fe8731677 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 2 Oct 2025 14:14:03 +0100 Subject: [PATCH 3/4] add test --- test/mcmc/hmc.jl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 893bab113..5f811b31d 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -203,8 +203,19 @@ using Turing return 10 ~ Normal(x) end chn1 = sample(StableRNG(468), f(), NUTS(), 100; save_state=true) - chn2 = sample(StableRNG(468), f(), NUTS(), 10; initial_state=chn1.info.samplerstate) - @test isapprox(chn1[:x], 5.0; atol=2.0) + # chn1 should end up around x = 5. + chn2 = sample( + StableRNG(468), + f(), + NUTS(), + 10; + nadapts=0, + discard_adapt=false, + initial_state=chn1.info.samplerstate, + ) + # if chn2 uses initial_state, its first sample should be somewhere around 5. if + # initial_state isn't used, it will be sampled from [-2, 2] so this test should fail + @test isapprox(chn2[:x][1], 5.0; atol=2.0) end @testset "(partially) issue: #2095" begin From 8935940e5bd4b5de4aeacc84862ce9c6f91e34ff Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 2 Oct 2025 14:15:10 +0100 Subject: [PATCH 4/4] bump patch --- HISTORY.md | 4 ++++ Project.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 2cdd2a644..8be2f43a6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.40.4 + +Fixes a bug where `initial_state` was not respected for NUTS if `resume_from` was not also specified. + # 0.40.3 This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains. diff --git a/Project.toml b/Project.toml index e047098f9..3939acc92 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.3" +version = "0.40.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"