Skip to content

Commit 7492b7f

Browse files
ChrisRackauckas-ClaudeChrisRackauckasclaude
authored
Simplify mixed precision implementations by computing types on demand (#759)
Remove cached T32 and Torig types from init_cacheval return tuples. Instead compute these types on demand in solve! functions to reduce complexity while maintaining zero allocations for subsequent solves. This change affects all mixed precision implementations: - MKL32MixedLUFactorization - OpenBLAS32MixedLUFactorization - AppleAccelerate32MixedLUFactorization - RF32MixedLUFactorization - CUDAOffload32MixedLUFactorization - MetalOffload32MixedLUFactorization 🤖 Generated with [Claude Code](https://claude.ai/code) Co-authored-by: ChrisRackauckas <accounts@chrisrackauckas.com> Co-authored-by: Claude <noreply@anthropic.com>
1 parent ae99918 commit 7492b7f

File tree

6 files changed

+81
-55
lines changed

6 files changed

+81
-55
lines changed

ext/LinearSolveCUDAExt.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,21 @@ end
120120
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
121121
kwargs...)
122122
if cache.isfresh
123-
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
124-
# Convert to Float32 for factorization using cached type
123+
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
124+
# Compute 32-bit type on demand and convert
125+
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
125126
A_f32 = T32.(cache.A)
126127
copyto!(A_gpu_f32, A_f32)
127128
fact = lu(A_gpu_f32)
128-
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
129+
cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
129130
cache.isfresh = false
130131
end
131-
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
132+
fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
133+
134+
# Compute types on demand for conversions
135+
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
136+
Torig = eltype(cache.u)
137+
132138
# Convert b to Float32, solve, then convert back to original precision
133139
b_f32 = T32.(cache.b)
134140
copyto!(b_gpu_f32, b_f32)
@@ -142,10 +148,9 @@ end
142148
function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b, u, Pl, Pr,
143149
maxiters::Int, abstol, reltol, verbose::Bool,
144150
assumptions::OperatorAssumptions)
145-
# Pre-allocate with Float32 arrays and cache types
151+
# Pre-allocate with Float32 arrays
146152
m, n = size(A)
147153
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
148-
Torig = eltype(u)
149154
noUnitT = typeof(zero(T32))
150155
luT = LinearAlgebra.lutype(noUnitT)
151156
ipiv = CuVector{Int32}(undef, min(m, n))
@@ -154,7 +159,7 @@ function LinearSolve.init_cacheval(alg::CUDAOffload32MixedLUFactorization, A, b,
154159
A_gpu_f32 = CuMatrix{T32}(undef, m, n)
155160
b_gpu_f32 = CuVector{T32}(undef, size(b, 1))
156161
u_gpu_f32 = CuVector{T32}(undef, size(u, 1))
157-
return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32, T32, Torig)
162+
return (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
158163
end
159164

160165
end

ext/LinearSolveMetalExt.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,9 @@ default_alias_b(::MetalOffload32MixedLUFactorization, ::Any, ::Any) = false
3636
function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b, u, Pl, Pr,
3737
maxiters::Int, abstol, reltol, verbose::Bool,
3838
assumptions::OperatorAssumptions)
39-
# Pre-allocate with Float32 arrays and cache types
39+
# Pre-allocate with Float32 arrays
4040
m, n = size(A)
4141
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
42-
Torig = eltype(u)
4342
A_f32 = similar(A, T32)
4443
b_f32 = similar(b, T32)
4544
u_f32 = similar(u, T32)
@@ -48,34 +47,40 @@ function LinearSolve.init_cacheval(alg::MetalOffload32MixedLUFactorization, A, b
4847
A_mtl = MtlArray{T32}(undef, m, n)
4948
b_mtl = MtlVector{T32}(undef, size(b, 1))
5049
u_mtl = MtlVector{T32}(undef, size(u, 1))
51-
return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
50+
return (luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl)
5251
end
5352

5453
function SciMLBase.solve!(cache::LinearCache, alg::MetalOffload32MixedLUFactorization;
5554
kwargs...)
5655
A = cache.A
5756
A = convert(AbstractMatrix, A)
5857
if cache.isfresh
59-
luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
60-
# Convert to appropriate 32-bit type for factorization using cached type
58+
luinst, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
59+
# Compute 32-bit type on demand and convert
60+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
6161
A_f32 .= T32.(A)
6262
copyto!(A_mtl, A_f32)
6363
res = lu(A_mtl)
6464
# Store factorization and pre-allocated arrays
6565
fact = LU(Array(res.factors), Array{Int}(res.ipiv), res.info)
66-
cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig)
66+
cache.cacheval = (fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl)
6767
cache.isfresh = false
6868
end
6969

70-
fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl, T32, Torig = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
71-
# Convert b to 32-bit for solving using cached type
70+
fact, A_f32, b_f32, u_f32, A_mtl, b_mtl, u_mtl = @get_cacheval(cache, :MetalOffload32MixedLUFactorization)
71+
72+
# Compute types on demand for conversions
73+
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
74+
Torig = eltype(cache.u)
75+
76+
# Convert b to 32-bit for solving
7277
b_f32 .= T32.(cache.b)
7378

7479
# Create a temporary Float32 LU factorization for solving
7580
fact_f32 = LU(T32.(fact.factors), fact.ipiv, fact.info)
7681
ldiv!(u_f32, fact_f32, b_f32)
7782

78-
# Convert back to original precision using cached type
83+
# Convert back to original precision
7984
cache.u .= Torig.(u_f32)
8085
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
8186
end

ext/LinearSolveRecursiveFactorizationExt.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,13 @@ function LinearSolve.init_cacheval(alg::RF32MixedLUFactorization{P, T}, A, b, u,
4747
# Pre-allocate appropriate 32-bit arrays based on input type
4848
m, n = size(A)
4949
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
50-
Torig = eltype(u)
5150
A_32 = similar(A, T32)
5251
b_32 = similar(b, T32)
5352
u_32 = similar(u, T32)
5453
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
5554
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(m, n))
56-
# Return tuple with pre-allocated arrays and cached types
57-
(luinst, ipiv, A_32, b_32, u_32, T32, Torig)
55+
# Return tuple with pre-allocated arrays
56+
(luinst, ipiv, A_32, b_32, u_32)
5857
end
5958

6059
function SciMLBase.solve!(
@@ -65,8 +64,9 @@ function SciMLBase.solve!(
6564

6665
if cache.isfresh
6766
# Get pre-allocated arrays from cacheval
68-
luinst, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
69-
# Copy A to pre-allocated 32-bit array using cached type
67+
luinst, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
68+
# Compute 32-bit type on demand and copy A
69+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
7070
A_32 .= T32.(A)
7171

7272
# Ensure ipiv is the right size
@@ -75,7 +75,7 @@ function SciMLBase.solve!(
7575
end
7676

7777
fact = RecursiveFactorization.lu!(A_32, ipiv, Val(P), Val(T), check = false)
78-
cache.cacheval = (fact, ipiv, A_32, b_32, u_32, T32, Torig)
78+
cache.cacheval = (fact, ipiv, A_32, b_32, u_32)
7979

8080
if !LinearAlgebra.issuccess(fact)
8181
return SciMLBase.build_linear_solution(
@@ -86,15 +86,19 @@ function SciMLBase.solve!(
8686
end
8787

8888
# Get the factorization and pre-allocated arrays from the cache
89-
fact_cached, ipiv, A_32, b_32, u_32, T32, Torig = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
89+
fact_cached, ipiv, A_32, b_32, u_32 = LinearSolve.@get_cacheval(cache, :RF32MixedLUFactorization)
9090

91-
# Copy b to pre-allocated 32-bit array using cached type
91+
# Compute types on demand for conversions
92+
T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
93+
Torig = eltype(cache.u)
94+
95+
# Copy b to pre-allocated 32-bit array
9296
b_32 .= T32.(cache.b)
9397

9498
# Solve in 32-bit precision
9599
ldiv!(u_32, fact_cached, b_32)
96100

97-
# Convert back to original precision using cached type
101+
# Convert back to original precision
98102
cache.u .= Torig.(u_32)
99103

100104
SciMLBase.build_linear_solution(

src/appleaccelerate.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,13 +300,12 @@ function LinearSolve.init_cacheval(alg::AppleAccelerate32MixedLUFactorization, A
300300
# Pre-allocate appropriate 32-bit arrays based on input type
301301
m, n = size(A)
302302
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
303-
Torig = eltype(u)
304303
A_32 = similar(A, T32)
305304
b_32 = similar(b, T32)
306305
u_32 = similar(u, T32)
307306
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
308-
# Return tuple with pre-allocated arrays and cached types
309-
(LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32, T32, Torig)
307+
# Return tuple with pre-allocated arrays
308+
(LU(luinst.factors, similar(A_32, Cint, 0), luinst.info), Ref{Cint}(), A_32, b_32, u_32)
310309
end
311310

312311
function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFactorization;
@@ -318,11 +317,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto
318317

319318
if cache.isfresh
320319
# Get pre-allocated arrays from cacheval
321-
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
322-
# Copy A to pre-allocated 32-bit array using cached type
320+
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
321+
# Compute 32-bit type on demand and copy A
322+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
323323
A_32 .= T32.(A)
324324
res = aa_getrf!(A_32; ipiv = luinst.ipiv, info = info)
325-
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
325+
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
326326
cache.cacheval = fact
327327

328328
if !LinearAlgebra.issuccess(fact[1])
@@ -332,21 +332,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerate32MixedLUFacto
332332
cache.isfresh = false
333333
end
334334

335-
A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
335+
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :AppleAccelerate32MixedLUFactorization)
336336
require_one_based_indexing(cache.u, cache.b)
337337
m, n = size(A_lu, 1), size(A_lu, 2)
338338

339-
# Copy b to pre-allocated 32-bit array using cached type
339+
# Compute types on demand for conversions
340+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
341+
Torig = eltype(cache.u)
342+
343+
# Copy b to pre-allocated 32-bit array
340344
b_32 .= T32.(cache.b)
341345

342346
if m > n
343347
aa_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
344-
# Convert back to original precision using cached type
348+
# Convert back to original precision
345349
cache.u[1:n] .= Torig.(b_32[1:n])
346350
else
347351
copyto!(u_32, b_32)
348352
aa_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
349-
# Convert back to original precision using cached type
353+
# Convert back to original precision
350354
cache.u .= Torig.(u_32)
351355
end
352356

src/mkl.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,12 @@ function LinearSolve.init_cacheval(alg::MKL32MixedLUFactorization, A, b, u, Pl,
283283
# Pre-allocate appropriate 32-bit arrays based on input type
284284
m, n = size(A)
285285
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
286-
Torig = eltype(u)
287286
A_32 = similar(A, T32)
288287
b_32 = similar(b, T32)
289288
u_32 = similar(u, T32)
290289
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
291-
# Return tuple with pre-allocated arrays and cached types
292-
(luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig)
290+
# Return tuple with pre-allocated arrays
291+
(luinst, Ref{BlasInt}(), A_32, b_32, u_32)
293292
end
294293

295294
function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;
@@ -301,11 +300,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;
301300

302301
if cache.isfresh
303302
# Get pre-allocated arrays from cacheval
304-
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization)
305-
# Copy A to pre-allocated 32-bit array using cached type
303+
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization)
304+
# Compute 32-bit type on demand and copy A
305+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
306306
A_32 .= T32.(A)
307307
res = getrf!(A_32; ipiv = luinst.ipiv, info = info)
308-
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
308+
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
309309
cache.cacheval = fact
310310

311311
if !LinearAlgebra.issuccess(fact[1])
@@ -315,21 +315,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKL32MixedLUFactorization;
315315
cache.isfresh = false
316316
end
317317

318-
A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :MKL32MixedLUFactorization)
318+
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :MKL32MixedLUFactorization)
319319
require_one_based_indexing(cache.u, cache.b)
320320
m, n = size(A_lu, 1), size(A_lu, 2)
321321

322-
# Copy b to pre-allocated 32-bit array using cached type
322+
# Compute types on demand for conversions
323+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
324+
Torig = eltype(cache.u)
325+
326+
# Copy b to pre-allocated 32-bit array
323327
b_32 .= T32.(cache.b)
324328

325329
if m > n
326330
getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
327-
# Convert back to original precision using cached type
331+
# Convert back to original precision
328332
cache.u[1:n] .= Torig.(b_32[1:n])
329333
else
330334
copyto!(u_32, b_32)
331335
getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
332-
# Convert back to original precision using cached type
336+
# Convert back to original precision
333337
cache.u .= Torig.(u_32)
334338
end
335339

src/openblas.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,12 @@ function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u,
308308
# Pre-allocate appropriate 32-bit arrays based on input type
309309
m, n = size(A)
310310
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
311-
Torig = eltype(u)
312311
A_32 = similar(A, T32)
313312
b_32 = similar(b, T32)
314313
u_32 = similar(u, T32)
315314
luinst = ArrayInterface.lu_instance(rand(T32, 0, 0))
316-
# Return tuple with pre-allocated arrays and cached types
317-
(luinst, Ref{BlasInt}(), A_32, b_32, u_32, T32, Torig)
315+
# Return tuple with pre-allocated arrays
316+
(luinst, Ref{BlasInt}(), A_32, b_32, u_32)
318317
end
319318

320319
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization;
@@ -326,11 +325,12 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio
326325

327326
if cache.isfresh
328327
# Get pre-allocated arrays from cacheval
329-
luinst, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
330-
# Copy A to pre-allocated 32-bit array using cached type
328+
luinst, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
329+
# Compute 32-bit type on demand and copy A
330+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
331331
A_32 .= T32.(A)
332332
res = openblas_getrf!(A_32; ipiv = luinst.ipiv, info = info)
333-
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32, T32, Torig)
333+
fact = (LU(res[1:3]...), res[4], A_32, b_32, u_32)
334334
cache.cacheval = fact
335335

336336
if !LinearAlgebra.issuccess(fact[1])
@@ -340,21 +340,25 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorizatio
340340
cache.isfresh = false
341341
end
342342

343-
A_lu, info, A_32, b_32, u_32, T32, Torig = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
343+
A_lu, info, A_32, b_32, u_32 = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
344344
require_one_based_indexing(cache.u, cache.b)
345345
m, n = size(A_lu, 1), size(A_lu, 2)
346346

347-
# Copy b to pre-allocated 32-bit array using cached type
347+
# Compute types on demand for conversions
348+
T32 = eltype(A) <: Complex ? ComplexF32 : Float32
349+
Torig = eltype(cache.u)
350+
351+
# Copy b to pre-allocated 32-bit array
348352
b_32 .= T32.(cache.b)
349353

350354
if m > n
351355
openblas_getrs!('N', A_lu.factors, A_lu.ipiv, b_32; info)
352-
# Convert back to original precision using cached type
356+
# Convert back to original precision
353357
cache.u[1:n] .= Torig.(b_32[1:n])
354358
else
355359
copyto!(u_32, b_32)
356360
openblas_getrs!('N', A_lu.factors, A_lu.ipiv, u_32; info)
357-
# Convert back to original precision using cached type
361+
# Convert back to original precision
358362
cache.u .= Torig.(u_32)
359363
end
360364

0 commit comments

Comments
 (0)