|
65 | 65 | alpha::Number, beta::Number) where {T<:BlasFloat} = |
66 | 66 | gemv!(y, 'N', A, x, alpha, beta) |
67 | 67 |
|
68 | | -# Complex matrix times real vector. Reinterpret the matrix as a real matrix and do real matvec compuation. |
69 | | -for elty in (Float32, Float64) |
70 | | - @eval begin |
71 | | - @inline function mul!(y::StridedVector{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, x::StridedVector{$elty}, |
72 | | - alpha::Real, beta::Real) |
73 | | - Afl = reinterpret($elty, A) |
74 | | - yfl = reinterpret($elty, y) |
75 | | - mul!(yfl, Afl, x, alpha, beta) |
76 | | - return y |
77 | | - end |
78 | | - end |
79 | | -end |
| 68 | +# Complex matrix times real vector. |
| 69 | +# Reinterpret the matrix as a real matrix and do real matvec compuation. |
| 70 | +@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, |
| 71 | + alpha::Number, beta::Number) where {T<:BlasReal} = |
| 72 | + gemv!(y, 'N', A, x, alpha, beta) |
80 | 73 |
|
81 | 74 | # Real matrix times complex vector. |
82 | 75 | # Multiply the matrix with the real and imaginary parts separately |
83 | 76 | @inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}}, |
84 | | - alpha::Number, beta::Number) where {T<:BlasFloat} = |
| 77 | + alpha::Number, beta::Number) where {T<:BlasReal} = |
85 | 78 | gemv!(y, A isa StridedArray ? 'N' : 'T', A isa StridedArray ? A : parent(A), x, alpha, beta) |
86 | 79 |
|
87 | 80 | @inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, |
|
192 | 185 | (*)(A::AdjOrTransStridedMat{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) = copy(transpose(transpose(B) * parent(A))) |
193 | 186 | (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::AdjOrTransStridedMat{<:BlasComplex}) = copy(wrapperop(B)(parent(B) * transpose(A))) |
194 | 187 |
|
195 | | -for elty in (Float32,Float64) |
196 | | - @eval begin |
197 | | - @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, B::StridedVecOrMat{$elty}, |
198 | | - alpha::Real, beta::Real) |
199 | | - Afl = reinterpret($elty, A) |
200 | | - Cfl = reinterpret($elty, C) |
201 | | - mul!(Cfl, Afl, B, alpha, beta) |
202 | | - return C |
203 | | - end |
204 | | - end |
205 | | -end |
206 | | - |
207 | 188 | """ |
208 | 189 | muladd(A, y, z) |
209 | 190 |
|
@@ -410,18 +391,14 @@ end |
410 | 391 | return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta)) |
411 | 392 | end |
412 | 393 | end |
413 | | -# Complex matrix times transposed real matrix. Reinterpret the first matrix to real for efficiency. |
414 | | -for elty in (Float32,Float64) |
415 | | - @eval begin |
416 | | - @inline function mul!(C::StridedMatrix{Complex{$elty}}, A::StridedVecOrMat{Complex{$elty}}, tB::Transpose{<:Any,<:StridedVecOrMat{$elty}}, |
417 | | - alpha::Real, beta::Real) |
418 | | - Afl = reinterpret($elty, A) |
419 | | - Cfl = reinterpret($elty, C) |
420 | | - mul!(Cfl, Afl, tB, alpha, beta) |
421 | | - return C |
422 | | - end |
423 | | - end |
424 | | -end |
| 394 | +# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. |
| 395 | +@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, |
| 396 | + alpha::Number, beta::Number) where {T<:BlasReal} = |
| 397 | + gemm_wrapper!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) |
| 398 | +@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, |
| 399 | + alpha::Number, beta::Number) where {T<:BlasReal} = |
| 400 | + gemm_wrapper!(C, 'N', 'T', A, parent(tB), MulAddMul(alpha, beta)) |
| 401 | + |
425 | 402 | # collapsing the following two defs with C::AbstractVecOrMat yields ambiguities |
426 | 403 | @inline mul!(C::AbstractVector, A::AbstractVecOrMat, tB::Transpose{<:Any,<:AbstractVecOrMat}, |
427 | 404 | alpha::Number, beta::Number) = |
@@ -513,22 +490,36 @@ end |
513 | 490 | function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::StridedVector{T}, |
514 | 491 | α::Number=true, β::Number=false) where {T<:BlasFloat} |
515 | 492 | mA, nA = lapack_size(tA, A) |
516 | | - if nA != length(x) |
| 493 | + nA != length(x) && |
517 | 494 | throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))")) |
518 | | - end |
519 | | - if mA != length(y) |
| 495 | + mA != length(y) && |
520 | 496 | throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))")) |
| 497 | + mA == 0 && return y |
| 498 | + nA == 0 && return _rmul_or_fill!(y, β) |
| 499 | + alpha, beta = promote(α, β, zero(T)) |
| 500 | + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && |
| 501 | + stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) |
| 502 | + return BLAS.gemv!(tA, alpha, A, x, beta, y) |
| 503 | + else |
| 504 | + return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) |
521 | 505 | end |
522 | | - if mA == 0 |
523 | | - return y |
524 | | - end |
525 | | - if nA == 0 |
526 | | - return _rmul_or_fill!(y, β) |
527 | | - end |
| 506 | +end |
528 | 507 |
|
| 508 | +function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, |
| 509 | + α::Number = true, β::Number = false) where {T<:BlasReal} |
| 510 | + mA, nA = lapack_size(tA, A) |
| 511 | + nA != length(x) && |
| 512 | + throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match length of x, $(length(x))")) |
| 513 | + mA != length(y) && |
| 514 | + throw(DimensionMismatch(lazy"first dimension of A, $mA, does not match length of y, $(length(y))")) |
| 515 | + mA == 0 && return y |
| 516 | + nA == 0 && return _rmul_or_fill!(y, β) |
529 | 517 | alpha, beta = promote(α, β, zero(T)) |
530 | | - if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) |
531 | | - return BLAS.gemv!(tA, alpha, A, x, beta, y) |
| 518 | + if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && |
| 519 | + stride(A, 1) == 1 && stride(A, 2) >= size(A, 1) && |
| 520 | + stride(y, 1) == 1 && tA == 'N' # reinterpret-based optimization is valid only for contiguous `y` |
| 521 | + BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) |
| 522 | + return y |
532 | 523 | else |
533 | 524 | return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) |
534 | 525 | end |
@@ -681,6 +672,49 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar |
681 | 672 | generic_matmatmul!(C, tA, tB, A, B, _add) |
682 | 673 | end |
683 | 674 |
|
| 675 | +function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, |
| 676 | + A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, |
| 677 | + _add = MulAddMul()) where {T<:BlasReal} |
| 678 | + mA, nA = lapack_size(tA, A) |
| 679 | + mB, nB = lapack_size(tB, B) |
| 680 | + |
| 681 | + if nA != mB |
| 682 | + throw(DimensionMismatch(lazy"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)")) |
| 683 | + end |
| 684 | + |
| 685 | + if C === A || B === C |
| 686 | + throw(ArgumentError("output matrix must not be aliased with input matrix")) |
| 687 | + end |
| 688 | + |
| 689 | + if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) |
| 690 | + if size(C) != (mA, nB) |
| 691 | + throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) |
| 692 | + end |
| 693 | + return _rmul_or_fill!(C, _add.beta) |
| 694 | + end |
| 695 | + |
| 696 | + if mA == 2 && nA == 2 && nB == 2 |
| 697 | + return matmul2x2!(C, tA, tB, A, B, _add) |
| 698 | + end |
| 699 | + if mA == 3 && nA == 3 && nB == 3 |
| 700 | + return matmul3x3!(C, tA, tB, A, B, _add) |
| 701 | + end |
| 702 | + |
| 703 | + alpha, beta = promote(_add.alpha, _add.beta, zero(T)) |
| 704 | + |
| 705 | + # Make-sure reinterpret-based optimization is BLAS-compatible. |
| 706 | + if (alpha isa Union{Bool,T} && |
| 707 | + beta isa Union{Bool,T} && |
| 708 | + stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && |
| 709 | + stride(A, 2) >= size(A, 1) && |
| 710 | + stride(B, 2) >= size(B, 1) && |
| 711 | + stride(C, 2) >= size(C, 1)) && tA == 'N' |
| 712 | + BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) |
| 713 | + return C |
| 714 | + end |
| 715 | + generic_matmatmul!(C, tA, tB, A, B, _add) |
| 716 | +end |
| 717 | + |
684 | 718 | # blas.jl defines matmul for floats; other integer and mixed precision |
685 | 719 | # cases are handled here |
686 | 720 |
|
|
0 commit comments