diff --git a/HISTORY.md b/HISTORY.md index 2cdd2a644..8be2f43a6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +# 0.40.4 + +Fixes a bug where `initial_state` was not respected for NUTS if `resume_from` was not also specified. + # 0.40.3 This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains. diff --git a/Project.toml b/Project.toml index e047098f9..3939acc92 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.40.3" +version = "0.40.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index 6ff975d4d..d776a68a8 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -120,6 +120,7 @@ function AbstractMCMC.sample( sampler, N; chain_type=chain_type, + initial_state=initial_state, progress=progress, nadapts=_nadapts, discard_initial=_discard_initial, diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 428c193ca..5f811b31d 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -197,6 +197,27 @@ using Turing @test_throws ErrorException sample(demo_impossible(), NUTS(), 5) end + @testset "NUTS initial parameters" begin + @model function f() + x ~ Normal() + return 10 ~ Normal(x) + end + chn1 = sample(StableRNG(468), f(), NUTS(), 100; save_state=true) + # chn1 should end up around x = 5. + chn2 = sample( + StableRNG(468), + f(), + NUTS(), + 10; + nadapts=0, + discard_adapt=false, + initial_state=chn1.info.samplerstate, + ) + # if chn2 uses initial_state, its first sample should be somewhere around 5. if + # initial_state isn't used, it will be sampled from [-2, 2] so this test should fail + @test isapprox(chn2[:x][1], 5.0; atol=2.0) + end + @testset "(partially) issue: #2095" begin @model function vector_of_dirichlet((::Type{TV})=Vector{Float64}) where {TV} xs = Vector{TV}(undef, 2)