Skip to content

Commit 33a71b7

Browse files
authored
Allow negative stride(A,2) in gemv! (#42054)
1 parent a947fc7 commit 33a71b7

File tree

2 files changed

+43
-34
lines changed

2 files changed

+43
-34
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -673,17 +673,26 @@ for (fname, elty) in ((:dgemv_,:Float64),
673673
end
674674
chkstride1(A)
675675
lda = stride(A,2)
676-
lda >= max(1, size(A,1)) || error("`stride(A,2)` must be at least `max(1, size(A,1))`")
677676
sX = stride(X,1)
678-
pX = pointer(X, sX > 0 ? firstindex(X) : lastindex(X))
679677
sY = stride(Y,1)
680-
pY = pointer(Y, sY > 0 ? firstindex(Y) : lastindex(Y))
681-
GC.@preserve X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
678+
if lda < 0
679+
colindex = lastindex(A, 2)
680+
lda = -lda
681+
trans == 'N' ? (sX = -sX) : (sY = -sY)
682+
else
683+
colindex = firstindex(A, 2)
684+
end
685+
lda >= size(A,1) || size(A,2) <= 1 || error("when `size(A,2) > 1`, `abs(stride(A,2))` must be at least `size(A,1)`")
686+
lda = max(1, size(A,1), lda)
687+
pA = pointer(A, Base._sub2ind(A, 1, colindex))
688+
pX = pointer(X, stride(X,1) > 0 ? firstindex(X) : lastindex(X))
689+
pY = pointer(Y, stride(Y,1) > 0 ? firstindex(Y) : lastindex(Y))
690+
GC.@preserve A X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
682691
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
683692
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
684693
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
685694
trans, size(A,1), size(A,2), alpha,
686-
A, lda, pX, sX,
695+
pA, lda, pX, sX,
687696
beta, pY, sY, 1)
688697
Y
689698
end

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -414,38 +414,38 @@ Random.seed!(100)
414414
@test all(BLAS.gemv('N', U4, o4) .== v41)
415415
@test all(BLAS.gemv('N', U4, o4) .== v41)
416416
@testset "non-standard strides" begin
417-
if elty <: Complex
418-
A = elty[1+2im 3+4im 5+6im 7+8im; 2+3im 4+5im 6+7im 8+9im; 3+4im 5+6im 7+8im 9+10im]
419-
v = elty[1+2im, 2+3im, 3+4im, 4+5im, 5+6im]
420-
dest = view(ones(elty, 7), 6:-2:2)
421-
@test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[-31+154im, -35+178im, -39+202im]
422-
@test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[15-41im, 17-49im]
423-
@test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[]
424-
dest = view(ones(elty, 5), 4:-2:2)
425-
@test BLAS.gemv!('T', elty(2), view(A, :, 2:2:4), view(v, 1:2:5), elty(3), dest) == elty[-45+202im, -69+370im]
426-
@test BLAS.gemv('T', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[14-38im, 18-54im]
427-
@test BLAS.gemv('T', view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
428-
dest = view(ones(elty, 5), 4:-2:2)
429-
@test BLAS.gemv!('C', elty(2), view(A, :, 2:2:4), view(v, 5:-2:1), elty(3), dest) == elty[179+6im, 347+30im]
430-
@test BLAS.gemv('C', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-40-6im, -56-10im]
431-
@test BLAS.gemv('C', view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
432-
else
433-
A = elty[1 2 3 4; 5 6 7 8; 9 10 11 12]
434-
v = elty[1, 2, 3, 4, 5]
435-
dest = view(ones(elty, 7), 6:-2:2)
436-
@test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[39, 79, 119]
437-
@test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-19, -31]
438-
@test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[]
439-
for trans = ('T', 'C')
440-
dest = view(ones(elty, 5), 4:-2:2)
441-
@test BLAS.gemv!(trans, elty(2), view(A, :, 2:2:4), view(v, 1:2:5), elty(3), dest) == elty[143, 179]
442-
@test BLAS.gemv(trans, elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-22, -25]
443-
@test BLAS.gemv(trans, view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
417+
A = rand(elty, 3, 4)
418+
x = rand(elty, 5)
419+
for y = (view(ones(elty, 5), 1:2:5), view(ones(elty, 7), 6:-2:2))
420+
ycopy = copy(y)
421+
@test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(x, 1:3:4), elty(3), y) ≈ 2*A[:,2:2:4]*x[1:3:4] + 3*ycopy
422+
ycopy = copy(y)
423+
@test BLAS.gemv!('N', elty(2), view(A, :, 4:-2:2), view(x, 1:3:4), elty(3), y) ≈ 2*A[:,4:-2:2]*x[1:3:4] + 3*ycopy
424+
ycopy = copy(y)
425+
@test BLAS.gemv!('N', elty(2), view(A, :, 2:2:4), view(x, 4:-3:1), elty(3), y) ≈ 2*A[:,2:2:4]*x[4:-3:1] + 3*ycopy
426+
ycopy = copy(y)
427+
@test BLAS.gemv!('N', elty(2), view(A, :, 4:-2:2), view(x, 4:-3:1), elty(3), y) ≈ 2*A[:,4:-2:2]*x[4:-3:1] + 3*ycopy
428+
ycopy = copy(y)
429+
@test BLAS.gemv!('N', elty(2), view(A, :, StepRangeLen(1,0,1)), view(x, 1:1), elty(3), y) ≈ 2*A[:,1:1]*x[1:1] + 3*ycopy # stride(A,2) == 0
430+
end
431+
@test BLAS.gemv!('N', elty(1), zeros(elty, 0, 5), zeros(elty, 5), elty(1), zeros(elty, 0)) == elty[] # empty matrix, stride(A,2) == 0
432+
@test BLAS.gemv('N', elty(-1), view(A, 2:3, 1:2:3), view(x, 2:-1:1)) ≈ -1*A[2:3,1:2:3]*x[2:-1:1]
433+
@test BLAS.gemv('N', view(A, 2:3, 3:-2:1), view(x, 1:2:3)) ≈ A[2:3,3:-2:1]*x[1:2:3]
434+
for (trans, f) = (('T',transpose), ('C',adjoint))
435+
for y = (view(ones(elty, 3), 1:2:3), view(ones(elty, 5), 4:-2:2))
436+
ycopy = copy(y)
437+
@test BLAS.gemv!(trans, elty(2), view(A, :, 2:2:4), view(x, 1:2:5), elty(3), y) ≈ 2*f(A[:,2:2:4])*x[1:2:5] + 3*ycopy
438+
ycopy = copy(y)
439+
@test BLAS.gemv!(trans, elty(2), view(A, :, 4:-2:2), view(x, 1:2:5), elty(3), y) ≈ 2*f(A[:,4:-2:2])*x[1:2:5] + 3*ycopy
440+
ycopy = copy(y)
441+
@test BLAS.gemv!(trans, elty(2), view(A, :, 2:2:4), view(x, 5:-2:1), elty(3), y) ≈ 2*f(A[:,2:2:4])*x[5:-2:1] + 3*ycopy
442+
ycopy = copy(y)
443+
@test BLAS.gemv!(trans, elty(2), view(A, :, 4:-2:2), view(x, 5:-2:1), elty(3), y) ≈ 2*f(A[:,4:-2:2])*x[5:-2:1] + 3*ycopy
444444
end
445+
@test BLAS.gemv!(trans, elty(2), view(A, :, StepRangeLen(1,0,1)), view(x, 1:2:5), elty(3), elty[1]) ≈ 2*f(A[:,1:1])*x[1:2:5] + elty[3] # stride(A,2) == 0
445446
end
446447
for trans = ('N', 'T', 'C')
447-
@test_throws ErrorException BLAS.gemv(trans, view(A, 1:2:3, 1:2), view(v, 1:2))
448-
@test_throws ErrorException BLAS.gemv(trans, view(A, 1:2, 2:-1:1), view(v, 1:2))
448+
@test_throws ErrorException BLAS.gemv(trans, view(A, 1:2:3, 1:2), view(x, 1:2)) # stride(A,1) must be 1
449449
end
450450
end
451451
end

0 commit comments

Comments
 (0)