Skip to content

Commit 0bdf24f

Browse files
authored
Move reinterpret-based optimization for complex matrix * real vec/mat to lower level. (#44052)
1 parent 33a71b7 commit 0bdf24f

File tree

2 files changed

+113
-76
lines changed

2 files changed

+113
-76
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -65,23 +65,16 @@ end
6565
alpha::Number, beta::Number) where {T<:BlasFloat} =
6666
gemv!(y, 'N', A, x, alpha, beta)
6767

68-
# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation.
69-
for elty in (Float32, Float64)
70-
@eval begin
71-
@inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty},
72-
alpha::Real, beta::Real)
73-
Afl = reinterpret($elty, A)
74-
yfl = reinterpret($elty, y)
75-
mul!(yfl, Afl, x, alpha, beta)
76-
return y
77-
end
78-
end
79-
end
68+
# Complex matrix times real vector.
69+
# Reinterpret the matrix as a real matrix and do real matvec compuation.
70+
@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
71+
alpha::Number, beta::Number) where {T<:BlasReal} =
72+
gemv!(y, 'N', A, x, alpha, beta)
8073

8174
# Real matrix times complex vector.
8275
# Multiply the matrix with the real and imaginary parts separately
8376
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
84-
alpha::Number, beta::Number) where {T<:BlasFloat} =
77+
alpha::Number, beta::Number) where {T<:BlasReal} =
8578
gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta)
8679

8780
@inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
@@ -192,18 +185,6 @@ end
192185
(*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) = copy(transpose(transpose(B) * parent(A)))
193186
(*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex}) = copy(wrapperop(B)(parent(B) * transpose(A)))
194187

195-
for elty in (Float32,Float64)
196-
@eval begin
197-
@inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty},
198-
alpha::Real, beta::Real)
199-
Afl = reinterpret($elty, A)
200-
Cfl = reinterpret($elty, C)
201-
mul!(Cfl, Afl, B, alpha, beta)
202-
return C
203-
end
204-
end
205-
end
206-
207188
"""
208189
muladd(A, y, z)
209190
@@ -410,18 +391,14 @@ end
410391
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
411392
end
412393
end
413-
# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency.
414-
for elty in (Float32,Float64)
415-
@eval begin
416-
@inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, tB::Transpose{<:Any,<:StridedVecOrMat{$elty}},
417-
alpha::Real, beta::Real)
418-
Afl = reinterpret($elty, A)
419-
Cfl = reinterpret($elty, C)
420-
mul!(Cfl, Afl, tB, alpha, beta)
421-
return C
422-
end
423-
end
424-
end
394+
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
395+
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
396+
alpha::Number, beta::Number) where {T<:BlasReal} =
397+
gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta))
398+
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
399+
alpha::Number, beta::Number) where {T<:BlasReal} =
400+
gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta))
401+
425402
# collapsing the following two defs with C::AbstractVecOrMat yields ambiguities
426403
@inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat},
427404
alpha::Number, beta::Number) =
@@ -513,22 +490,36 @@ end
513490
function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T},
514491
α::Number=true, β::Number=false) where {T<:BlasFloat}
515492
mA, nA = lapack_size(tA, A)
516-
if nA != length(x)
493+
nA != length(x) &&
517494
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
518-
end
519-
if mA != length(y)
495+
mA != length(y) &&
520496
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
497+
mA == 0 && return y
498+
nA == 0 && return _rmul_or_fill!(y, β)
499+
alpha, beta = promote(α, β, zero(T))
500+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
501+
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
502+
return BLAS.gemv!(tA, alpha, A, x, beta, y)
503+
else
504+
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
521505
end
522-
if mA == 0
523-
return y
524-
end
525-
if nA == 0
526-
return _rmul_or_fill!(y, β)
527-
end
506+
end
528507

508+
function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
509+
α::Number = true, β::Number = false) where {T<:BlasReal}
510+
mA, nA = lapack_size(tA, A)
511+
nA != length(x) &&
512+
throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))"))
513+
mA != length(y) &&
514+
throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))"))
515+
mA == 0 && return y
516+
nA == 0 && return _rmul_or_fill!(y, β)
529517
alpha, beta = promote(α, β, zero(T))
530-
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1)
531-
return BLAS.gemv!(tA, alpha, A, x, beta, y)
518+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
519+
stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) &&
520+
stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y`
521+
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
522+
return y
532523
else
533524
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
534525
end
@@ -681,6 +672,49 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
681672
generic_matmatmul!(C, tA, tB, A, B, _add)
682673
end
683674

675+
function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
676+
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
677+
_add = MulAddMul()) where {T<:BlasReal}
678+
mA, nA = lapack_size(tA, A)
679+
mB, nB = lapack_size(tB, B)
680+
681+
if nA != mB
682+
throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
683+
end
684+
685+
if C === A || B === C
686+
throw(ArgumentError("output matrix must not be aliased with input matrix"))
687+
end
688+
689+
if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha)
690+
if size(C) != (mA, nB)
691+
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
692+
end
693+
return _rmul_or_fill!(C, _add.beta)
694+
end
695+
696+
if mA == 2 && nA == 2 && nB == 2
697+
return matmul2x2!(C, tA, tB, A, B, _add)
698+
end
699+
if mA == 3 && nA == 3 && nB == 3
700+
return matmul3x3!(C, tA, tB, A, B, _add)
701+
end
702+
703+
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
704+
705+
# Make-sure reinterpret-based optimization is BLAS-compatible.
706+
if (alpha isa Union{Bool,T} &&
707+
beta isa Union{Bool,T} &&
708+
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
709+
stride(A, 2) >= size(A, 1) &&
710+
stride(B, 2) >= size(B, 1) &&
711+
stride(C, 2) >= size(C, 1)) && tA == 'N'
712+
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
713+
return C
714+
end
715+
generic_matmatmul!(C, tA, tB, A, B, _add)
716+
end
717+
684718
# blas.jl defines matmul for floats; other integer and mixed precision
685719
# cases are handled here
686720

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -226,34 +226,37 @@ end
226226
end
227227
end
228228

229-
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T1 in (Float32, Float64)
230-
for T2 in (Float32, Float64)
231-
for arg1_real in (true, false)
232-
@testset "Combination $T1 $T2 $arg1_real $arg2_real" for arg2_real in (true, false)
233-
A0 = reshape(Vector{T1}(1:25), 5, 5) .+
234-
(arg1_real ? 0 : 1im * reshape(Vector{T1}(-3:21), 5, 5))
235-
A = view(A0, 1:2, 1:2)
236-
B = Matrix{T2}([1.0 3.0; -1.0 2.0]) .+
237-
(arg2_real ? 0 : 1im * Matrix{T2}([3.0 4; -1 10]))
238-
AB_correct = copy(A) * B
239-
AB = A * B # view times matrix
240-
@test AB AB_correct
241-
A1 = view(A0, :, 1:2) # rectangular view times matrix
242-
@test A1 * B copy(A1) * B
243-
B1 = view(B, 1:2, 1:2)
244-
AB1 = A * B1 # view times view
245-
@test AB1 AB_correct
246-
x = Vector{T2}([1.0; 10.0]) .+ (arg2_real ? 0 : 1im * Vector{T2}([3; -1]))
247-
Ax_exact = copy(A) * x
248-
Ax = A * x # view times vector
249-
@test Ax Ax_exact
250-
x1 = view(x, 1:2)
251-
Ax1 = A * x1 # view times viewed vector
252-
@test Ax1 Ax_exact
253-
@test copy(A) * x1 Ax_exact # matrix times viewed vector
254-
# View times transposed matrix
255-
Bt = transpose(B)
256-
@test A * Bt A * copy(Bt)
229+
@testset "Complex matrix x real MatOrVec etc (issue #29224)" for T in (Float32, Float64)
230+
A0 = randn(complex(T), 10, 10)
231+
B0 = randn(T, 10, 10)
232+
@testset "Combination Mat{$(complex(T))} Mat{$T}" for Bax1 in (1:5, 2:2:10), Bax2 in (1:5, 2:2:10)
233+
B = view(A0, Bax1, Bax2)
234+
tB = transpose(B)
235+
Bd, tBd = copy(B), copy(tB)
236+
for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
237+
A = view(A0, Aax1, Aax2)
238+
AB_correct = copy(A) * Bd
239+
AtB_correct = copy(A) * tBd
240+
@test A*Bd AB_correct # view times matrix
241+
@test A*B AB_correct # view times view
242+
@test A*tBd AtB_correct # view times transposed matrix
243+
@test A*tB AtB_correct # view times transposed view
244+
end
245+
end
246+
x = randn(T, 10)
247+
y0 = similar(A0, 20)
248+
@testset "Combination Mat{$(complex(T))} Vec{$T}" for Aax1 in (1:5, 2:2:10, (:)), Aax2 in (1:5, 2:2:10)
249+
A = view(A0, Aax1, Aax2)
250+
Ad = copy(A)
251+
for indx in (1:5, 1:2:10, 6:-1:2)
252+
vx = view(x, indx)
253+
dx = x[indx]
254+
Ax_correct = Ad*dx
255+
@test A*vx A*dx Ad*vx Ax_correct # view/matrix times view/vector
256+
for indy in (1:2:2size(A,1), size(A,1):-1:1)
257+
y = view(y0, indy)
258+
@test mul!(y, A, vx) mul!(y, A, dx) mul!(y, Ad, vx)
259+
mul!(y, Ad, dx) Ax_correct # test for uncontiguous dest
257260
end
258261
end
259262
end

0 commit comments

Comments
 (0)