diff --git a/src/nodes/predefined.jl b/src/nodes/predefined.jl index 7bc1bf17b..5257b4de3 100644 --- a/src/nodes/predefined.jl +++ b/src/nodes/predefined.jl @@ -34,6 +34,7 @@ include("predefined/continuous_transition.jl") include("predefined/half_normal.jl") include("predefined/binomial_polya.jl") include("predefined/multinomial_polya.jl") +include("predefined/sigmoid.jl") include("predefined/flow/flow.jl") include("predefined/delta/delta.jl") diff --git a/src/nodes/predefined/sigmoid.jl b/src/nodes/predefined/sigmoid.jl new file mode 100644 index 000000000..7c4c80e33 --- /dev/null +++ b/src/nodes/predefined/sigmoid.jl @@ -0,0 +1,18 @@ +using StatsFuns: logistic, softplus +using Distributions: pdf + +export Sigmoid + +struct Sigmoid end + +@node Sigmoid Stochastic [out, in, ζ] + +@average_energy Sigmoid (q_out::Categorical, q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass) = begin + m_out = pdf(q_out, 1) + m_in, v_in = mean_var(q_in) + + ζ_hat = mean(q_ζ) + + U = -(m_in * m_out - softplus(-ζ_hat) - (0.5 * (m_in + ζ_hat)) - 0.5 * ((logistic(ζ_hat) - 0.5)/ζ_hat) * (m_in^2 + v_in - ζ_hat^2)) + return U +end diff --git a/src/rules/gamma_mixture/a.jl b/src/rules/gamma_mixture/a.jl index 9623a082e..9c00830e4 100644 --- a/src/rules/gamma_mixture/a.jl +++ b/src/rules/gamma_mixture/a.jl @@ -1,5 +1,5 @@ -@rule GammaMixture((:a, k), Marginalisation) (q_out::Any, q_switch::Any, q_b::GammaDistributionsFamily) = begin +@rule GammaMixture((:a, k), Marginalisation) (q_out::GammaDistributionsFamily, q_switch::Any, q_b::GammaDistributionsFamily) = begin p = probvec(q_switch)[k] β = mean(log, q_out) + mean(log, q_b) γ = p * β diff --git a/src/rules/predefined.jl b/src/rules/predefined.jl index a406bc2f2..883f4608a 100644 --- a/src/rules/predefined.jl +++ b/src/rules/predefined.jl @@ -198,3 +198,7 @@ include("multinomial_polya/x.jl") include("dirichlet_collection/out.jl") include("dirichlet_collection/marginals.jl") + +include("sigmoid/in.jl") +include("sigmoid/out.jl") +include("sigmoid/zeta.jl") diff --git a/src/rules/sigmoid/in.jl b/src/rules/sigmoid/in.jl new file mode 100644 index 000000000..1fe0e92db --- /dev/null +++ b/src/rules/sigmoid/in.jl @@ -0,0 +1,19 @@ +using Distributions: pdf +using StatsFuns: logistic +@rule Sigmoid(:in, Marginalisation) (q_out::Categorical, q_ζ::PointMass) = begin + m_out = pdf(q_out, 1) + ζ_hat = mean(q_ζ) + w = (logistic(ζ_hat) - 0.5)/ζ_hat + ξ = (m_out - 0.5) * w + T = promote_type(eltype(m_out), eltype(ζ_hat)) + return NormalWeightedMeanPrecision{T}(ξ, w) +end + +@rule Sigmoid(:in, Marginalisation) (q_out::PointMass, q_ζ::PointMass) = begin + m_out = mean(q_out) + ζ_hat = mean(q_ζ) + w = (logistic(ζ_hat) - 0.5)/ζ_hat + ξ = (m_out - 0.5) * w + T = promote_type(eltype(m_out), eltype(ζ_hat)) + return NormalWeightedMeanPrecision{T}(ξ, w) +end diff --git a/src/rules/sigmoid/out.jl b/src/rules/sigmoid/out.jl new file mode 100644 index 000000000..f7af8f1c4 --- /dev/null +++ b/src/rules/sigmoid/out.jl @@ -0,0 +1,11 @@ +using StatsFuns: logistic +@rule Sigmoid(:out, Marginalisation) (q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass) = begin + m_in = mean(q_in) + ζ_hat = mean(q_ζ) + p = logistic(m_in) + T = promote_type(eltype(m_in), eltype(ζ_hat)) + probs = clamp.([p, 1 - p], tiny, 1 - tiny) + probs ./= sum(probs) + probs_T = convert(Vector{T}, probs) + return Categorical(probs_T) +end diff --git a/src/rules/sigmoid/zeta.jl b/src/rules/sigmoid/zeta.jl new file mode 100644 index 000000000..8d4da1086 --- /dev/null +++ b/src/rules/sigmoid/zeta.jl @@ -0,0 +1,5 @@ +@rule Sigmoid(:ζ, Marginalisation) (q_out::Any, q_in::UnivariateNormalDistributionsFamily) = begin + m_in, v_in = mean_var(q_in) + T = promote_type(eltype(m_in), eltype(v_in)) + return PointMass{T}(sqrt(m_in^2 + v_in)) +end diff --git a/test/nodes/predefined/sigmoid_tests.jl b/test/nodes/predefined/sigmoid_tests.jl new file mode 100644 index 000000000..d5148633f --- /dev/null +++ b/test/nodes/predefined/sigmoid_tests.jl @@ -0,0 +1,18 @@ +@testitem "sigmoidNode" begin + using ReactiveMP, Random, BayesBase, ExponentialFamily + import ReactiveMP: Sigmoid + + @testset "Average energy" begin + q_in = NormalMeanVariance(0.0, 1.0) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in) + @test score( + AverageEnergy(), + Sigmoid, + Val{(:out, :in, :ζ)}(), + (Marginal(Categorical(0.5, 0.5), false, false, nothing), Marginal(q_in_adj, false, false, nothing), Marginal(PointMass(1.0), false, false, nothing)), + nothing + ) ≈ 0.8132616875182228 + end + end +end diff --git a/test/rules/sigmoid/in_tests.jl b/test/rules/sigmoid/in_tests.jl new file mode 100644 index 000000000..ace1e835d --- /dev/null +++ b/test/rules/sigmoid/in_tests.jl @@ -0,0 +1,22 @@ +@testitem "rules:Sigmoid:in" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_out::Categorical, q_ζ::PointMass) - Float64" begin + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:in, Marginalisation) [ + (input = (q_out = Categorical([0.5, 0.5]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.0, 0.2310585786300049)), + (input = (q_out = Categorical([1.0, 0.0]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.11552928931500245, 0.2310585786300049)), + (input = (q_out = Categorical([0.0, 1.0]), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(-0.11552928931500245, 0.2310585786300049)) + ] + end + + @testset "Mean Field: (q_out::PointMass, q_ζ::PointMass) - Float64" begin + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:in, Marginalisation) [ + (input = (q_out = PointMass(0.5), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.0, 0.2310585786300049)), + (input = (q_out = PointMass(1.0), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(0.11552928931500245, 0.2310585786300049)), + (input = (q_out = PointMass(0.0), q_ζ = PointMass(1.0)), output = NormalWeightedMeanPrecision(-0.11552928931500245, 0.2310585786300049)) + ] + end +end diff --git a/test/rules/sigmoid/out_tests.jl b/test/rules/sigmoid/out_tests.jl new file mode 100644 index 000000000..92fde765f --- /dev/null +++ b/test/rules/sigmoid/out_tests.jl @@ -0,0 +1,19 @@ +@testitem "rules:Sigmoid:out" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_in::UnivariateNormalDistributionsFamily, q_ζ::PointMass)" begin + q_in = [NormalMeanVariance(0.0, 1.0), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(10.0, 1.0)] + results = [[0.5, 0.5], [0.2689414213699951, 0.7310585786300049], [0.9999546021312976, 4.5397868702390376e-5]] + for (i, result) in enumerate(results) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in[i]) + @test_rules [check_type_promotion = true, atol = [Float64 => 1e-5]] Sigmoid(:out, Marginalisation) [( + input = (q_in = q_in_adj, q_ζ = PointMass(2.0)), output = Categorical(result) + )] + end + end + end +end diff --git a/test/rules/sigmoid/zeta_tests.jl b/test/rules/sigmoid/zeta_tests.jl new file mode 100644 index 000000000..7e8668f3b --- /dev/null +++ b/test/rules/sigmoid/zeta_tests.jl @@ -0,0 +1,19 @@ +@testitem "rules:Sigmoid:zeta" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + using StatsFuns: logistic + + import ReactiveMP: @test_rules + + @testset "Mean Field: (q_out::Any, q_in::UnivariateNormalDistributionsFamily)" begin + q_in = [NormalMeanVariance(0.0, 1.0), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(10.0, 1.0)] + results = [1.0, 1.4142135623730951, 10.04987562112089] + for (i, result) in enumerate(results) + for normal_fam in (NormalMeanVariance, NormalMeanPrecision, NormalWeightedMeanPrecision) + q_in_adj = convert(normal_fam, q_in[i]) + @test_rules [check_type_promotion = false, atol = [Float64 => 1e-5]] Sigmoid(:ζ, Marginalisation) [( + input = (q_out = 2.0, q_in = q_in_adj), output = PointMass(result) + )] + end + end + end +end