Skip to content
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# AdvancedHMC Changelog

## 0.9.0

- Stochastic gradient based methods `SGHMC` and `SGLD` are supported in AdvancedHMC.jl. Please note that there are similar methods with the same names in Turing.jl, so when using the two packages together, qualify the method with the package that exports it (e.g. `AdvancedHMC.SGHMC`).

## 0.8.0

- To make an MCMC transition from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`).
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.8.0"
version = "0.9.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"

[compat]
AdvancedHMC = "0.8"
AdvancedHMC = "0.9"
Documenter = "1"
DocumenterCitations = "1"
DocumenterCitations = "1"
2 changes: 1 addition & 1 deletion src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ include("sampler.jl")
export sample

include("constructors.jl")
export HMCSampler, HMC, NUTS, HMCDA
export HMCSampler, HMC, NUTS, HMCDA, SGHMC

include("abstractmcmc.jl")

Expand Down
122 changes: 122 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,120 @@
return Transition(t.z, tstat), newstate
end

"""
    SGHMCState

State passed between successive `AbstractMCMC.step` calls of the `SGHMC` sampler.

It holds the iteration counter, the most recent transition, the metric, kernel and
adaptor in use, and the velocity vector carried over to the next update.
"""
struct SGHMCState{
TTrans<:Transition,
TMetric<:AbstractMetric,
TKernel<:AbstractMCMCKernel,
TAdapt<:Adaptation.AbstractAdaptor,
T<:AbstractVector{<:Real},
}
"Index of current iteration."
i::Int
"Current [`Transition`](@ref)."
transition::TTrans
"Current [`AbstractMetric`](@ref), possibly adapted."
metric::TMetric
"Current [`AbstractMCMCKernel`](@ref)."
κ::TKernel
"Current [`AbstractAdaptor`](@ref)."
adaptor::TAdapt
"Current velocity; read at the start of each step and replaced by the updated velocity."
velocity::T
end
# Accessors for the components stored in an `SGHMCState`.

# Return the adaptor stored in `state`.
function getadaptor(state::SGHMCState)
    return state.adaptor
end

# Return the metric stored in `state`.
function getmetric(state::SGHMCState)
    return state.metric
end

# Return the integrator, reached through the kernel `κ` and its `τ` field.
function getintegrator(state::SGHMCState)
    kernel = state.κ
    return kernel.τ.integrator
end

Check warning on line 229 in src/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/abstractmcmc.jl#L227-L229

Added lines #L227 - L229 were not covered by tests

# Initial step for the `SGHMC` sampler (no previous state).
#
# Builds the metric, Hamiltonian, step size, integrator, kernel and adaptor from
# `spl` and `model`, draws an initial transition at `initial_params`, wraps
# everything in an `SGHMCState` with iteration index 0, and delegates to the
# state-carrying `AbstractMCMC.step` method to produce the first sample.
#
# Keyword arguments:
# - `initial_params`: starting position; forwarded to `make_initial_params`
#   (default `nothing`).
# - `kwargs...`: forwarded unchanged to the state-carrying `step` method.
#
# Returns whatever the state-carrying `step` method returns (sample, new state).
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC;
    initial_params=nothing,
    kwargs...,
)
    # Unpack model
    logdensity = model.logdensity

    # Define metric
    metric = make_metric(spl, logdensity)

    # Construct the hamiltonian using the initial metric
    hamiltonian = Hamiltonian(metric, model)

    # Compute initial sample and state.
    initial_params = make_initial_params(rng, spl, logdensity, initial_params)
    ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
    integrator = make_integrator(spl, ϵ)

    # Make kernel
    κ = make_kernel(spl, integrator)

    # Make adaptor
    adaptor = make_adaptor(spl, metric, integrator)

    # Get an initial sample. The Hamiltonian returned by `sample_init` is not
    # needed here: the next step rebuilds it from the metric stored in the state.
    _, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)

    # Start with zero velocity. The velocity is a displacement added to the
    # position on every update, so seeding it with the position itself
    # (`initial_params`) would make the very first update jump by
    # `(1 - α) * initial_params` regardless of the gradient.
    state = SGHMCState(0, t, metric, κ, adaptor, zero(initial_params))

    return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

# Subsequent step for the `SGHMC` sampler.
#
# Takes the previous `SGHMCState`, applies one velocity/position update driven by
# the gradient of the log density, runs a kernel transition from the updated
# position, calls `adapt!`, and returns `(sample, newstate)`.
#
# Keyword arguments:
# - `n_adapts`: number of adaptation steps passed to `adapt!` (default 0).
# - `nadapts` is rejected with an `ArgumentError`.
function AbstractMCMC.step(
    rng::AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    spl::SGHMC,
    state::SGHMCState;
    n_adapts::Int=0,
    kwargs...,
)
    # Guard against the misspelled keyword before doing any work.
    haskey(kwargs, :nadapts) && throw(
        ArgumentError(
            "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
        ),
    )

    iteration = state.i + 1

    # Rebuild the Hamiltonian from the metric carried in the state.
    hamiltonian = Hamiltonian(state.metric, model)

    # Work on a copy of the previous position so the stored transition is untouched.
    position = copy(state.transition.z.θ)
    gradient = last(
        LogDensityProblems.logdensity_and_gradient(model.logdensity, position)
    )

    # Velocity and position update, following equation (15) of Chen et al. (2014).
    velocity = state.velocity
    η = spl.learning_rate
    α = spl.momentum_decay
    noise = randn(rng, eltype(velocity), length(velocity))
    new_velocity = (1 - α) .* velocity .+ η .* gradient .+ sqrt(2 * η * α) .* noise
    position .+= new_velocity

    # Run the kernel from the updated position (the old velocity is passed as momentum).
    proposal = transition(
        rng, hamiltonian, state.κ, phasepoint(hamiltonian, position, velocity)
    )

    # Adaptation and statistics.
    stats = stat(proposal)
    hamiltonian, kernel, isadapted = adapt!(
        hamiltonian,
        state.κ,
        state.adaptor,
        iteration,
        n_adapts,
        position,
        stats.acceptance_rate,
    )
    stats = merge(stats, (is_adapt=isadapted,))

    # Package the sample and the state for the next iteration.
    next_state = SGHMCState(
        iteration, proposal, hamiltonian.metric, kernel, state.adaptor, new_velocity
    )
    return Transition(proposal.z, stats), next_state
end

################
### Callback ###
################
Expand Down Expand Up @@ -392,6 +506,10 @@
return NoAdaptation()
end

# `SGHMC` always uses `NoAdaptation()`; none of the arguments are inspected.
make_adaptor(
    spl::SGHMC, metric::AbstractMetric, integrator::AbstractIntegrator
) = NoAdaptation()

function make_adaptor(
spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator
)
Expand All @@ -417,3 +535,7 @@
# An `HMCSampler` already carries its kernel in `spl.κ`; `integrator` is not used.
make_kernel(spl::HMCSampler, integrator::AbstractIntegrator) = spl.κ

# Build the kernel for `SGHMC`: an `HMCKernel` over an `EndPointTS` trajectory that
# takes a fixed `spl.n_leapfrog` number of steps with the given integrator.
function make_kernel(spl::SGHMC, integrator::AbstractIntegrator)
    termination = FixedNSteps(spl.n_leapfrog)
    trajectory = Trajectory{EndPointTS}(integrator, termination)
    return HMCKernel(trajectory)
end
45 changes: 45 additions & 0 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal)
end

# The sampler element type is the first type parameter of `HMCDA`.
function sampler_eltype(::HMCDA{T}) where {T}
    return T
end

########### Static Hamiltonian Monte Carlo ###########

#############
### SGHMC ###
#############
"""
    SGHMC(learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal)

Stochastic Gradient Hamiltonian Monte Carlo sampler.

`learning_rate` and `momentum_decay` are converted to a common element type;
`integrator` and `metric` are keyword arguments.

# Fields

$(FIELDS)

# Notes

For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)):

- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic gradient hamiltonian monte carlo." International conference on machine learning. PMLR, 2014.
"""
struct SGHMC{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
AbstractHMCSampler
"Learning rate `η`; scales the gradient term in the velocity update."
learning_rate::T
"Momentum decay rate `α`."
momentum_decay::T
"Number of leapfrog steps."
n_leapfrog::Int
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::I
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::M
end

# Convenience constructor: determines a common element type from the arguments and
# converts `learning_rate` and `momentum_decay` to it before building the sampler.
# `integrator` and `metric` default to `:leapfrog` and `:diagonal`.
function SGHMC(
    learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal
)
    FloatT = determine_sampler_eltype(
        learning_rate, momentum_decay, n_leapfrog, integrator, metric
    )
    η = FloatT(learning_rate)
    α = FloatT(momentum_decay)
    return SGHMC(η, α, n_leapfrog, integrator, metric)
end

# The sampler element type is the first type parameter of `SGHMC`.
function sampler_eltype(::SGHMC{T}) where {T}
    return T
end
24 changes: 24 additions & 0 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Statistics: mean
nuts = NUTS(0.8)
hmc = HMC(100; integrator=Leapfrog(0.05))
hmcda = HMCDA(0.8, 0.1)
sghmc = SGHMC(0.01, 0.1, 100)

integrator = Leapfrog(1e-3)
κ = AdvancedHMC.make_kernel(nuts, integrator)
Expand Down Expand Up @@ -111,6 +112,29 @@ using Statistics: mean

@test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

samples_sghmc = AbstractMCMC.sample(
rng,
model,
sghmc,
n_adapts + n_samples;
n_adapts=n_adapts,
initial_params=θ_init,
progress=false,
verbose=false,
)

# Transform back to original space.
# NOTE: We're not correcting for the `logabsdetjac` here, but since
# we're only interested in the mean it doesn't matter.
for t in samples_sghmc
t.z.θ .= invlink_gdemo(t.z.θ)
end
m_est_sghmc = mean(samples_sghmc) do t
t.z.θ
end

@test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

samples_custom = AbstractMCMC.sample(
rng,
model,
Expand Down
Loading