From cf2edd82270a3eaf7a4ba7541f6035f9b0b66747 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 14 Aug 2025 12:20:07 -0400 Subject: [PATCH 1/3] make sure that Partials length is known --- ext/LinearSolveForwardDiffExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/LinearSolveForwardDiffExt.jl b/ext/LinearSolveForwardDiffExt.jl index fb0ec730f..77f8a8659 100644 --- a/ext/LinearSolveForwardDiffExt.jl +++ b/ext/LinearSolveForwardDiffExt.jl @@ -34,7 +34,7 @@ const DualBLinearProblem = LinearProblem{ const DualAbstractLinearProblem = Union{ DualLinearProblem, DualALinearProblem, DualBLinearProblem} -LinearSolve.@concrete mutable struct DualLinearCache{DT <: Dual} +LinearSolve.@concrete mutable struct DualLinearCache{DT} linear_cache partials_A @@ -113,10 +113,10 @@ function linearsolve_dual_solution( end function linearsolve_dual_solution(u::AbstractArray, partials, - cache::DualLinearCache{DT}) where {DT} + cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}} # Handle single-level duals for arrays partials_list = RecursiveArrayTools.VectorOfArray(partials) - return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials(Tuple(pᵢ))), + return map(((uᵢ, pᵢ),) -> DT(uᵢ, Partials{N,V}(NTuple{N,V}(pᵢ))), zip(u, partials_list[i, :] for i in 1:length(partials_list.u[1]))) end From dd049b08dd0929c009755ef31ec0b72521f13d15 Mon Sep 17 00:00:00 2001 From: jClugstor Date: Fri, 15 Aug 2025 09:31:52 -0400 Subject: [PATCH 2/3] add Dual problem JET tests --- test/nopre/jet.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/test/nopre/jet.jl b/test/nopre/jet.jl index 853916eba..8f428cc4a 100644 --- a/test/nopre/jet.jl +++ b/test/nopre/jet.jl @@ -1,4 +1,4 @@ -using LinearSolve, RecursiveFactorization, LinearAlgebra, SparseArrays, Test +using LinearSolve, ForwardDiff, RecursiveFactorization, LinearAlgebra, SparseArrays, Test using JET # Dense problem setup @@ -22,6 +22,18 @@ prob_sparse = LinearProblem(A_sparse, b) A_sparse_spd = sparse(A_spd) prob_sparse_spd = LinearProblem(A_sparse_spd, b) +# Dual problem set up +function h(p) + (A = [p[1] p[2]+1 p[2]^3; + 3*p[1] p[1]+5 p[2] * p[1]-4; + p[2]^2 9*p[1] p[2]], + b = [p[1] + 1, p[2] * 2, p[1]^2]) +end + +A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)]) + +dual_prob = LinearProblem(A, b) + @testset "JET Tests for Dense Factorizations" begin # Working tests - these pass JET optimization checks JET.@test_opt init(prob, nothing) @@ -108,3 +120,11 @@ end JET.@test_opt solve(prob) broken=true JET.@test_opt solve(prob_sparse) broken=true end + +@testset "JET Tests for creating Dual solutions" begin + # Make sure there's no runtime dispatch when making solutions of Dual problems + dual_cache = init(prob) + ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt) + JET.@test_opt ext.linearsolve_dual_solution( + [1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], cache) +end \ No newline at end of file From 010269e53ff9f70b3eee0d5bca010404fce3d980 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 15 Aug 2025 14:34:18 +0000 Subject: [PATCH 3/3] Update test/nopre/jet.jl --- test/nopre/jet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nopre/jet.jl b/test/nopre/jet.jl index 8f428cc4a..d774efc34 100644 --- a/test/nopre/jet.jl +++ b/test/nopre/jet.jl @@ -126,5 +126,5 @@ end dual_cache = init(prob) ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt) JET.@test_opt ext.linearsolve_dual_solution( - [1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], cache) + [1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dual_cache ) end \ No newline at end of file