diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl index 77796409b..c69b747cf 100644 --- a/ext/LinearSolveCUDAExt.jl +++ b/ext/LinearSolveCUDAExt.jl @@ -120,15 +120,21 @@ end function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization; kwargs...) if cache.isfresh - fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) - # Convert to Float32 for factorization using cached type + fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) + # Compute 32-bit type on demand and convert + T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32 A_f32 = T32.(cache.A) copyto!(A_gpu_f32, A_f32) fact = lu(A_gpu_f32) - cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig) + cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32) cache.isfresh = false end - fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) + fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization) + + # Compute types on demand for conversions + T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + # Convert b to Float32, solve, then convert back to original precision b_f32 = T32.(cache.b) copyto!(b_gpu_f32, b_f32) @@ -142,10 +148,9 @@ end function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - # Pre-allocate with Float32 arrays and cache types + # Pre-allocate with Float32 arrays m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) noUnitT = typeof(zero(T32)) luT = LinearAlgebra.lutype(noUnitT) ipiv = CuVector{Int32}(undef, min(m, n)) @@ -154,7 +159,7 @@ function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, A_gpu_f32 = CuMatrix{T32}(undef, m, n) b_gpu_f32 = CuVector{T32}(undef, size(b, 1)) u_gpu_f32 = CuVector{T32}(undef, size(u, 1)) - return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig) + return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32) end end diff --git a/ext/LinearSolveMetalExt.jl b/ext/LinearSolveMetalExt.jl index 81f497725..083bd82c3 100644 --- a/ext/LinearSolveMetalExt.jl +++ b/ext/LinearSolveMetalExt.jl @@ -36,10 +36,9 @@ default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - # Pre-allocate with Float32 arrays and cache types + # Pre-allocate with Float32 arrays m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) A_f32 = similar(A, T32) b_f32 = similar(b, T32) u_f32 = similar(u, T32) @@ -48,7 +47,7 @@ function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b A_mtl = MtlArray{T32}(undef, m, n) b_mtl = MtlVector{T32}(undef, size(b, 1)) u_mtl = MtlVector{T32}(undef, size(u, 1)) - return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig) + return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl) end function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization; @@ -56,26 +55,32 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactoriz A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh - luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) - # Convert to appropriate 32-bit type for factorization using cached type + luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) + # Compute 32-bit type on demand and convert + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 A_f32 .= T32.(A) copyto!(A_mtl, A_f32) res = lu(A_mtl) # Store factorization and pre-allocated arrays fact = LU(Array(res.factors), Array{Int}(res.ipiv), res.info) - cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig) + cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl) cache.isfresh = false end - fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) - # Convert b to 32-bit for solving using cached type + fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization) + + # Compute types on demand for conversions + T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + + # Convert b to 32-bit for solving b_f32 .= T32.(cache.b) # Create a temporary Float32 LU factorization for solving fact_f32 = LU(T32.(fact.factors), fact.ipiv, fact.info) ldiv!(u_f32, fact_f32, b_f32) - # Convert back to original precision using cached type + # Convert back to original precision cache.u .= Torig.(u_f32) SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end diff --git a/ext/LinearSolveRecursiveFactorizationExt.jl b/ext/LinearSolveRecursiveFactorizationExt.jl index 340b53838..947dd8020 100644 --- a/ext/LinearSolveRecursiveFactorizationExt.jl +++ b/ext/LinearSolveRecursiveFactorizationExt.jl @@ -47,14 +47,13 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u, # Pre-allocate appropriate 32-bit arrays based on input type m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) A_32 = similar(A, T32) b_32 = similar(b, T32) u_32 = similar(u, T32) luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n)) - # Return tuple with pre-allocated arrays and cached types - (luinst, ipiv, A_32, b_32, u_32, T32, Torig) + # Return tuple with pre-allocated arrays + (luinst, ipiv, A_32, b_32, u_32) end function SciMLBase.solve!( @@ -65,8 +64,9 @@ function SciMLBase.solve!( if cache.isfresh # Get pre-allocated arrays from cacheval - luinst, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) - # Copy A to pre-allocated 32-bit array using cached type + luinst, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) + # Compute 32-bit type on demand and copy A + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 A_32 .= T32.(A) # Ensure ipiv is the right size @@ -75,7 +75,7 @@ function SciMLBase.solve!( end fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false) - cache.cacheval = (fact, ipiv, A_32, b_32, u_32, T32, Torig) + cache.cacheval = (fact, ipiv, A_32, b_32, u_32) if !LinearAlgebra.issuccess(fact) return SciMLBase.build_linear_solution( @@ -86,15 +86,19 @@ function SciMLBase.solve!( end # Get the factorization and pre-allocated arrays from the cache - fact_cached, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) + fact_cached, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization) - # Copy b to pre-allocated 32-bit array using cached type + # Compute types on demand for conversions + T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + + # Copy b to pre-allocated 32-bit array b_32 .= T32.(cache.b) # Solve in 32-bit precision ldiv!(u_32, fact_cached, b_32) - # Convert back to original precision using cached type + # Convert back to original precision cache.u .= Torig.(u_32) SciMLBase.build_linear_solution( diff --git a/src/appleaccelerate.jl b/src/appleaccelerate.jl index 6de1567ee..0b031ddbc 100644 --- a/src/appleaccelerate.jl +++ b/src/appleaccelerate.jl @@ -300,13 +300,12 @@ function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A # Pre-allocate appropriate 32-bit arrays based on input type m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) A_32 = similar(A, T32) b_32 = similar(b, T32) u_32 = similar(u, T32) luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) - # Return tuple with pre-allocated arrays and cached types - (LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32, T32, Torig) + # Return tuple with pre-allocated arrays + (LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32) end function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization; @@ -318,11 +317,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto if cache.isfresh # Get pre-allocated arrays from cacheval - luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) - # Copy A to pre-allocated 32-bit array using cached type + luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) + # Compute 32-bit type on demand and copy A + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 A_32 .= T32.(A) res = aa_getrf!(A_32; ipiv = luinst.ipiv, info = info) - fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -332,21 +332,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto cache.isfresh = false end - A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Copy b to pre-allocated 32-bit array using cached type + # Compute types on demand for conversions + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + + # Copy b to pre-allocated 32-bit array b_32 .= T32.(cache.b) if m > n aa_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u[1:n] .= Torig.(b_32[1:n]) else copyto!(u_32, b_32) aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u .= Torig.(u_32) end diff --git a/src/mkl.jl b/src/mkl.jl index 0eab16140..0453f8f1a 100644 --- a/src/mkl.jl +++ b/src/mkl.jl @@ -283,13 +283,12 @@ function LinearSolve.init_cacheval(alg::MKL32MixedLUFactorization, A, b, u, Pl, # Pre-allocate appropriate 32-bit arrays based on input type m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) A_32 = similar(A, T32) b_32 = similar(b, T32) u_32 = similar(u, T32) luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) - # Return tuple with pre-allocated arrays and cached types - (luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig) + # Return tuple with pre-allocated arrays + (luinst, Ref{BlasInt}(), A_32, b_32, u_32) end function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; @@ -301,11 +300,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; if cache.isfresh # Get pre-allocated arrays from cacheval - luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization) - # Copy A to pre-allocated 32-bit array using cached type + luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization) + # Compute 32-bit type on demand and copy A + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 A_32 .= T32.(A) res = getrf!(A_32; ipiv = luinst.ipiv, info = info) - fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -315,21 +315,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization; cache.isfresh = false end - A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Copy b to pre-allocated 32-bit array using cached type + # Compute types on demand for conversions + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + + # Copy b to pre-allocated 32-bit array b_32 .= T32.(cache.b) if m > n getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u[1:n] .= Torig.(b_32[1:n]) else copyto!(u_32, b_32) getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u .= Torig.(u_32) end diff --git a/src/openblas.jl b/src/openblas.jl index 3830c9a39..96abb6f14 100644 --- a/src/openblas.jl +++ b/src/openblas.jl @@ -308,13 +308,12 @@ function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u, # Pre-allocate appropriate 32-bit arrays based on input type m, n = size(A) T32 = eltype(A) <: Complex ? ComplexF32 : Float32 - Torig = eltype(u) A_32 = similar(A, T32) b_32 = similar(b, T32) u_32 = similar(u, T32) luinst = ArrayInterface.lu_instance(rand(T32, 0, 0)) - # Return tuple with pre-allocated arrays and cached types - (luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig) + # Return tuple with pre-allocated arrays + (luinst, Ref{BlasInt}(), A_32, b_32, u_32) end function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization; @@ -326,11 +325,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio if cache.isfresh # Get pre-allocated arrays from cacheval - luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) - # Copy A to pre-allocated 32-bit array using cached type + luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) + # Compute 32-bit type on demand and copy A + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 A_32 .= T32.(A) res = openblas_getrf!(A_32; ipiv = luinst.ipiv, info = info) - fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig) + fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32) cache.cacheval = fact if !LinearAlgebra.issuccess(fact[1]) @@ -340,21 +340,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio cache.isfresh = false end - A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) + A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization) require_one_based_indexing(cache.u, cache.b) m, n = size(A_lu, 1), size(A_lu, 2) - # Copy b to pre-allocated 32-bit array using cached type + # Compute types on demand for conversions + T32 = eltype(A) <: Complex ? ComplexF32 : Float32 + Torig = eltype(cache.u) + + # Copy b to pre-allocated 32-bit array b_32 .= T32.(cache.b) if m > n openblas_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u[1:n] .= Torig.(b_32[1:n]) else copyto!(u_32, b_32) openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info) - # Convert back to original precision using cached type + # Convert back to original precision cache.u .= Torig.(u_32) end