Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions ext/LinearSolveCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,21 @@ end
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
kwargs...)
if cache.isfresh
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
# Convert to Float32 for factorization using cached type
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
# Compute 32-bit type on demand and convert
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
A_f32 = T32.(cache.A)
copyto!(A_gpu_f32, A_f32)
fact = lu(A_gpu_f32)
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
cache.isfresh = false
end
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)

# Compute types on demand for conversions
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Convert b to Float32, solve, then convert back to original precision
b_f32 = T32.(cache.b)
copyto!(b_gpu_f32, b_f32)
Expand All @@ -142,10 +148,9 @@ end
function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
# Pre-allocate with Float32 arrays and cache types
# Pre-allocate with Float32 arrays
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
noUnitT = typeof(zero(T32))
luT = LinearAlgebra.lutype(noUnitT)
ipiv = CuVector{Int32}(undef, min(m, n))
Expand All @@ -154,7 +159,7 @@ function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b,
A_gpu_f32 = CuMatrix{T32}(undef, m, n)
b_gpu_f32 = CuVector{T32}(undef, size(b, 1))
u_gpu_f32 = CuVector{T32}(undef, size(u, 1))
return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
end

end
23 changes: 14 additions & 9 deletions ext/LinearSolveMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
# Pre-allocate with Float32 arrays and cache types
# Pre-allocate with Float32 arrays
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
A_f32 = similar(A, T32)
b_f32 = similar(b, T32)
u_f32 = similar(u, T32)
Expand All @@ -48,34 +47,40 @@ function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b
A_mtl = MtlArray{T32}(undef, m, n)
b_mtl = MtlVector{T32}(undef, size(b, 1))
u_mtl = MtlVector{T32}(undef, size(u, 1))
return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl)
end

function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
# Convert to appropriate 32-bit type for factorization using cached type
luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
# Compute 32-bit type on demand and convert
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_f32 .= T32.(A)
copyto!(A_mtl, A_f32)
res = lu(A_mtl)
# Store factorization and pre-allocated arrays
fact = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl)
cache.isfresh = false
end

fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
# Convert b to 32-bit for solving using cached type
fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)

# Compute types on demand for conversions
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Convert b to 32-bit for solving
b_f32 .= T32.(cache.b)

# Create a temporary Float32 LU factorization for solving
fact_f32 = LU(T32.(fact.factors), fact.ipiv, fact.info)
ldiv!(u_f32, fact_f32, b_f32)

# Convert back to original precision using cached type
# Convert back to original precision
cache.u .= Torig.(u_f32)
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
Expand Down
22 changes: 13 additions & 9 deletions ext/LinearSolveRecursiveFactorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u,
# Pre-allocate appropriate 32-bit arrays based on input type
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
A_32 = similar(A, T32)
b_32 = similar(b, T32)
u_32 = similar(u, T32)
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n))
# Return tuple with pre-allocated arrays and cached types
(luinst, ipiv, A_32, b_32, u_32, T32, Torig)
# Return tuple with pre-allocated arrays
(luinst, ipiv, A_32, b_32, u_32)
end

function SciMLBase.solve!(
Expand All @@ -65,8 +64,9 @@ function SciMLBase.solve!(

if cache.isfresh
# Get pre-allocated arrays from cacheval
luinst, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
# Copy A to pre-allocated 32-bit array using cached type
luinst, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
# Compute 32-bit type on demand and copy A
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 .= T32.(A)

# Ensure ipiv is the right size
Expand All @@ -75,7 +75,7 @@ function SciMLBase.solve!(
end

fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false)
cache.cacheval = (fact, ipiv, A_32, b_32, u_32, T32, Torig)
cache.cacheval = (fact, ipiv, A_32, b_32, u_32)

if !LinearAlgebra.issuccess(fact)
return SciMLBase.build_linear_solution(
Expand All @@ -86,15 +86,19 @@ function SciMLBase.solve!(
end

# Get the factorization and pre-allocated arrays from the cache
fact_cached, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
fact_cached, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)

# Copy b to pre-allocated 32-bit array using cached type
# Compute types on demand for conversions
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Copy b to pre-allocated 32-bit array
b_32 .= T32.(cache.b)

# Solve in 32-bit precision
ldiv!(u_32, fact_cached, b_32)

# Convert back to original precision using cached type
# Convert back to original precision
cache.u .= Torig.(u_32)

SciMLBase.build_linear_solution(
Expand Down
24 changes: 14 additions & 10 deletions src/appleaccelerate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,12 @@ function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A
# Pre-allocate appropriate 32-bit arrays based on input type
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
A_32 = similar(A, T32)
b_32 = similar(b, T32)
u_32 = similar(u, T32)
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
# Return tuple with pre-allocated arrays and cached types
(LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32, T32, Torig)
# Return tuple with pre-allocated arrays
(LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32)
end

function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
Expand All @@ -318,11 +317,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto

if cache.isfresh
# Get pre-allocated arrays from cacheval
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
# Copy A to pre-allocated 32-bit array using cached type
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
# Compute 32-bit type on demand and copy A
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 .= T32.(A)
res = aa_getrf!(A_32; ipiv = luinst.ipiv, info = info)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
cache.cacheval = fact

if !LinearAlgebra.issuccess(fact[1])
Expand All @@ -332,21 +332,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto
cache.isfresh = false
end

A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
require_one_based_indexing(cache.u, cache.b)
m, n = size(A_lu, 1), size(A_lu, 2)

# Copy b to pre-allocated 32-bit array using cached type
# Compute types on demand for conversions
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Copy b to pre-allocated 32-bit array
b_32 .= T32.(cache.b)

if m > n
aa_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u[1:n] .= Torig.(b_32[1:n])
else
copyto!(u_32, b_32)
aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u .= Torig.(u_32)
end

Expand Down
24 changes: 14 additions & 10 deletions src/mkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,12 @@ function LinearSolve.init_cacheval(alg::MKL32MixedLUFactorization, A, b, u, Pl,
# Pre-allocate appropriate 32-bit arrays based on input type
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
A_32 = similar(A, T32)
b_32 = similar(b, T32)
u_32 = similar(u, T32)
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
# Return tuple with pre-allocated arrays and cached types
(luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig)
# Return tuple with pre-allocated arrays
(luinst, Ref{BlasInt}(), A_32, b_32, u_32)
end

function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;
Expand All @@ -301,11 +300,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;

if cache.isfresh
# Get pre-allocated arrays from cacheval
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization)
# Copy A to pre-allocated 32-bit array using cached type
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization)
# Compute 32-bit type on demand and copy A
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 .= T32.(A)
res = getrf!(A_32; ipiv = luinst.ipiv, info = info)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
cache.cacheval = fact

if !LinearAlgebra.issuccess(fact[1])
Expand All @@ -315,21 +315,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;
cache.isfresh = false
end

A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization)
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization)
require_one_based_indexing(cache.u, cache.b)
m, n = size(A_lu, 1), size(A_lu, 2)

# Copy b to pre-allocated 32-bit array using cached type
# Compute types on demand for conversions
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Copy b to pre-allocated 32-bit array
b_32 .= T32.(cache.b)

if m > n
getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u[1:n] .= Torig.(b_32[1:n])
else
copyto!(u_32, b_32)
getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u .= Torig.(u_32)
end

Expand Down
24 changes: 14 additions & 10 deletions src/openblas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,12 @@ function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u,
# Pre-allocate appropriate 32-bit arrays based on input type
m, n = size(A)
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(u)
A_32 = similar(A, T32)
b_32 = similar(b, T32)
u_32 = similar(u, T32)
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
# Return tuple with pre-allocated arrays and cached types
(luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig)
# Return tuple with pre-allocated arrays
(luinst, Ref{BlasInt}(), A_32, b_32, u_32)
end

function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization;
Expand All @@ -326,11 +325,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio

if cache.isfresh
# Get pre-allocated arrays from cacheval
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
# Copy A to pre-allocated 32-bit array using cached type
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
# Compute 32-bit type on demand and copy A
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
A_32 .= T32.(A)
res = openblas_getrf!(A_32; ipiv = luinst.ipiv, info = info)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
cache.cacheval = fact

if !LinearAlgebra.issuccess(fact[1])
Expand All @@ -340,21 +340,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio
cache.isfresh = false
end

A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
require_one_based_indexing(cache.u, cache.b)
m, n = size(A_lu, 1), size(A_lu, 2)

# Copy b to pre-allocated 32-bit array using cached type
# Compute types on demand for conversions
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
Torig = eltype(cache.u)

# Copy b to pre-allocated 32-bit array
b_32 .= T32.(cache.b)

if m > n
openblas_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u[1:n] .= Torig.(b_32[1:n])
else
copyto!(u_32, b_32)
openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
# Convert back to original precision using cached type
# Convert back to original precision
cache.u .= Torig.(u_32)
end

Expand Down
Loading