diff --git a/src/AbstractGPs.jl b/src/AbstractGPs.jl index b0cd6885..87ffaa34 100644 --- a/src/AbstractGPs.jl +++ b/src/AbstractGPs.jl @@ -59,6 +59,9 @@ include("latent_gp.jl") # Plotting utilities. include("util/plotting.jl") +# Autodiff utilities +include("util/ad.jl") + # Testing utilities. include("util/TestUtils.jl") diff --git a/src/finite_gp_projection.jl b/src/finite_gp_projection.jl index d09cf139..52833ecc 100644 --- a/src/finite_gp_projection.jl +++ b/src/finite_gp_projection.jl @@ -4,7 +4,8 @@ The finite-dimensional projection of the AbstractGP `f` at `x`. Assumed to be observed under Gaussian noise with zero mean and covariance matrix `Σy` """ -struct FiniteGP{Tf<:AbstractGP,Tx<:AbstractVector,TΣ} <: AbstractMvNormal +struct FiniteGP{Tf<:AbstractGP,Tx<:AbstractVector,TΣ<:AbstractMatrix{<:Real}} <: + AbstractMvNormal f::Tf x::Tx Σy::TΣ @@ -17,7 +18,11 @@ end const default_σ² = 1e-18 function FiniteGP(f::AbstractGP, x::AbstractVector, σ²::Real=default_σ²) - return FiniteGP(f, x, Fill(σ², length(x))) + return FiniteGP(f, x, ScalMat(length(x), σ²)) +end + +function FiniteGP(f::AbstractGP, x::AbstractVector, σ²::UniformScaling) + return FiniteGP(f, x, σ²[1, 1]) end ## conversions diff --git a/src/util/ad.jl b/src/util/ad.jl new file mode 100644 index 00000000..388a2389 --- /dev/null +++ b/src/util/ad.jl @@ -0,0 +1,6 @@ +import ChainRulesCore: ProjectTo, Tangent +using PDMats: ScalMat + +ProjectTo(x::T) where {T<:ScalMat} = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value)) +(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value)) +(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value)) diff --git a/test/finite_gp_projection.jl b/test/finite_gp_projection.jl index d75c1596..8c34a24c 100644 --- a/test/finite_gp_projection.jl +++ b/test/finite_gp_projection.jl @@ -205,6 +205,15 @@ end first(FiniteDifferences.grad(central_fdm(3, 1), Base.Fix1(logpdf, fx), y)) @test Distributions.sqmahal!(r, fx, Y) ≈ Distributions.sqmahal(fx, Y) end + + @testset "FiniteGP with UniformScaling" begin + f = GP(SqExponentialKernel()) + fx = f(rand(10), 2.0 * I) + # for now, just check that it runs + _ = mean(fx) + _ = mean_and_cov(fx) + _ = rand(fx) + end end @testset "Docs" begin