Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 7 additions & 34 deletions benchmark/continuous_transition_bench.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,52 +54,30 @@ function create_benchmark_data(dx, dy)
m_y = MvNormalMeanCovariance(μy, Σy)
m_x = MvNormalMeanCovariance(μx, Σx)

return (
meta = meta,
q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W,
q_y_x = q_y_x,
m_y = m_y, m_x = m_x
)
return (meta = meta, q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, q_y_x = q_y_x, m_y = m_y, m_x = m_x)
end

# ============================================================================
# Benchmark Functions
# ============================================================================

function bench_a_structured(data)
@call_rule ContinuousTransition(:a, Marginalisation) (
q_y_x = data.q_y_x,
q_a = data.q_a,
q_W = data.q_W,
meta = data.meta
)
@call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = data.q_y_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
end

function bench_a_meanfield(data)
@call_rule ContinuousTransition(:a, Marginalisation) (
q_y = data.q_y,
q_x = data.q_x,
q_a = data.q_a,
q_W = data.q_W,
meta = data.meta
)
@call_rule ContinuousTransition(:a, Marginalisation) (q_y = data.q_y, q_x = data.q_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
end

function bench_marginal_y_x(data)
@call_marginalrule ContinuousTransition(:y_x) (
m_y = data.m_y,
m_x = data.m_x,
q_a = data.q_a,
q_W = data.q_W,
meta = data.meta
)
@call_marginalrule ContinuousTransition(:y_x) (m_y = data.m_y, m_x = data.m_x, q_a = data.q_a, q_W = data.q_W, meta = data.meta)
end

# ============================================================================
# Benchmark Runner
# ============================================================================

function run_benchmarks(; quick_mode=false)
function run_benchmarks(; quick_mode = false)
println()
println("=" ^ 80)
println(" ContinuousTransition Rules Benchmark")
Expand All @@ -118,11 +96,7 @@ function run_benchmarks(; quick_mode=false)
println()

# Results storage
results = Dict{String, Vector{Tuple{Int, Int, Float64}}}(
"a_structured" => [],
"a_meanfield" => [],
"marginal_y_x" => []
)
results = Dict{String, Vector{Tuple{Int, Int, Float64}}}("a_structured" => [], "a_meanfield" => [], "marginal_y_x" => [])

for (dx, dy) in test_dims
println("-" ^ 60)
Expand Down Expand Up @@ -192,5 +166,4 @@ end
# ============================================================================

quick_mode = length(ARGS) > 0 && ARGS[1] == "quick"
run_benchmarks(quick_mode=quick_mode)

run_benchmarks(quick_mode = quick_mode)
31 changes: 4 additions & 27 deletions benchmark/rules/continuous_transition/continuous_transition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ function create_ct_benchmark_data(dx, dy)
m_y = MvNormalMeanCovariance(μy, Σy)
m_x = MvNormalMeanCovariance(μx, Σx)

return (
meta = meta,
q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W,
q_y_x = q_y_x,
m_y = m_y, m_x = m_x
)
return (meta = meta, q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, q_y_x = q_y_x, m_y = m_y, m_x = m_x)
end

"""
Expand All @@ -72,23 +67,12 @@ function add_continuous_transition_a_benchmarks(SUITE)

# Structured VMP: q(y,x) joint
SUITE["a"]["Structured"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
@call_rule ContinuousTransition(:a, Marginalisation) (
q_y_x = $data.q_y_x,
q_a = $data.q_a,
q_W = $data.q_W,
meta = $data.meta
)
@call_rule ContinuousTransition(:a, Marginalisation) (q_y_x = $data.q_y_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
end

# Mean-field VMP: q(y)q(x)q(a)q(W)
SUITE["a"]["Mean-field"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
@call_rule ContinuousTransition(:a, Marginalisation) (
q_y = $data.q_y,
q_x = $data.q_x,
q_a = $data.q_a,
q_W = $data.q_W,
meta = $data.meta
)
@call_rule ContinuousTransition(:a, Marginalisation) (q_y = $data.q_y, q_x = $data.q_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
end
end
end
Expand All @@ -104,14 +88,7 @@ function add_continuous_transition_marginals_benchmarks(SUITE)

# y_x marginal rule
SUITE["marginals"]["y_x"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin
@call_marginalrule ContinuousTransition(:y_x) (
m_y = $data.m_y,
m_x = $data.m_x,
q_a = $data.q_a,
q_W = $data.q_W,
meta = $data.meta
)
@call_marginalrule ContinuousTransition(:y_x) (m_y = $data.m_y, m_x = $data.m_x, q_a = $data.q_a, q_W = $data.q_W, meta = $data.meta)
end
end
end