Skip to content
3 changes: 3 additions & 0 deletions src/AbstractGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ include("latent_gp.jl")
# Plotting utilities.
include("util/plotting.jl")

# Autodiff utilities
include("util/ad.jl")

# Testing utilities.
include("util/TestUtils.jl")

Expand Down
9 changes: 7 additions & 2 deletions src/finite_gp_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
The finite-dimensional projection of the AbstractGP `f` at `x`. Assumed to be observed under
Gaussian noise with zero mean and covariance matrix `Σy`
"""
struct FiniteGP{Tf<:AbstractGP,Tx<:AbstractVector,TΣ} <: AbstractMvNormal
struct FiniteGP{Tf<:AbstractGP,Tx<:AbstractVector,TΣ<:AbstractMatrix{<:Real}} <:
AbstractMvNormal
f::Tf
x::Tx
Σy::TΣ
Expand All @@ -17,7 +18,11 @@
const default_σ² = 1e-18

function FiniteGP(f::AbstractGP, x::AbstractVector, σ²::Real=default_σ²)
return FiniteGP(f, x, Fill(σ², length(x)))
return FiniteGP(f, x, ScalMat(length(x), σ²))

Check warning on line 21 in src/finite_gp_projection.jl

View check run for this annotation

Codecov / codecov/patch

src/finite_gp_projection.jl#L21

Added line #L21 was not covered by tests
end

function FiniteGP(f::AbstractGP, x::AbstractVector, σ²::UniformScaling)
return FiniteGP(f, x, σ²[1, 1])

Check warning on line 25 in src/finite_gp_projection.jl

View check run for this annotation

Codecov / codecov/patch

src/finite_gp_projection.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end

## conversions
Expand Down
6 changes: 6 additions & 0 deletions src/util/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import ChainRulesCore: ProjectTo, Tangent
using PDMats: ScalMat

ProjectTo(x::T) where {T<:ScalMat} = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value))
(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value))
(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value))

Check warning on line 6 in src/util/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/util/ad.jl#L4-L6

Added lines #L4 - L6 were not covered by tests
9 changes: 9 additions & 0 deletions test/finite_gp_projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,15 @@ end
first(FiniteDifferences.grad(central_fdm(3, 1), Base.Fix1(logpdf, fx), y))
@test Distributions.sqmahal!(r, fx, Y) ≈ Distributions.sqmahal(fx, Y)
end

@testset "FiniteGP with UniformScaling" begin
f = GP(SqExponentialKernel())
fx = f(rand(10), 2.0 * I)
# for now, just check that it runs
_ = mean(fx)
_ = mean_and_cov(fx)
_ = rand(fx)
end
end

@testset "Docs" begin
Expand Down
Loading