@@ -173,6 +173,36 @@ if VERSION < v"1.8-"
173173 return B
174174 end
175175else
176+ function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
177+ D:: Diagonal{<:Any, <:AbstractGPUArray} ,
178+ A:: AbstractGPUVecOrMat )
179+ dd = D. diag
180+ d = length (dd)
181+ m, n = size (A, 1 ), size (A, 2 )
182+ m′, n′ = size (B, 1 ), size (B, 2 )
183+ m == d || throw (DimensionMismatch (" right hand side has $m rows but D is $d by $d " ))
184+ (m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
185+ @. B = dd * A
186+
187+ B
188+ end
189+
190+ function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
191+ D:: Diagonal{<:Any, <:AbstractGPUArray} ,
192+ A:: AbstractGPUVecOrMat ,
193+ α:: Number ,
194+ β:: Number )
195+ dd = D. diag
196+ d = length (dd)
197+ m, n = size (A, 1 ), size (A, 2 )
198+ m′, n′ = size (B, 1 ), size (B, 2 )
199+ m == d || throw (DimensionMismatch (" right hand side has $m rows but D is $d by $d " ))
200+ (m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
201+ @. B = α * dd* A + β * B
202+
203+ B
204+ end
205+
176206 function LinearAlgebra. ldiv! (B:: AbstractGPUVecOrMat ,
177207 D:: Diagonal{<:Any, <:AbstractGPUArray} ,
178208 A:: AbstractGPUVecOrMat )
0 commit comments