Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ext/LinearSolveForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const DualBLinearProblem = LinearProblem{
const DualAbstractLinearProblem = Union{
DualLinearProblem, DualALinearProblem, DualBLinearProblem}

LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual}
LinearSolve.@concrete mutable struct DualLinearCache{DT}
linear_cache

partials_A
Expand Down Expand Up @@ -113,10 +113,10 @@ function linearsolve_dual_solution(
end

function linearsolve_dual_solution(u::AbstractArray, partials,
cache::DualLinearCache{DT}) where {DT}
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}}
# Handle single-level duals for arrays
partials_list = RecursiveArrayTools.VectorOfArray(partials)
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))),
return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials{N,V}(NTuple{N,V}(pᵢ))),
zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1])))
end

Expand Down
22 changes: 21 additions & 1 deletion test/nopre/jet.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, RecursiveFactorization, LinearAlgebra, SparseArrays, Test
using LinearSolve, ForwardDiff, RecursiveFactorization, LinearAlgebra, SparseArrays, Test
using JET

# Dense problem setup
Expand All @@ -22,6 +22,18 @@ prob_sparse = LinearProblem(A_sparse, b)
A_sparse_spd = sparse(A_spd)
prob_sparse_spd = LinearProblem(A_sparse_spd, b)

# Dual problem set up
function h(p)
(A = [p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2] * p[1]-4;
p[2]^2 9*p[1] p[2]],
b = [p[1] + 1, p[2] * 2, p[1]^2])
end

A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])

dual_prob = LinearProblem(A, b)

@testset "JET Tests for Dense Factorizations" begin
# Working tests - these pass JET optimization checks
JET.@test_opt init(prob, nothing)
Expand Down Expand Up @@ -108,3 +120,11 @@ end
JET.@test_opt solve(prob) broken=true
JET.@test_opt solve(prob_sparse) broken=true
end

@testset "JET Tests for creating Dual solutions" begin
# Make sure there's no runtime dispatch when making solutions of Dual problems
dual_cache = init(prob)
ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt)
JET.@test_opt ext.linearsolve_dual_solution(
[1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache )
end
Loading