@@ -296,45 +296,63 @@ end
296296# # norm
297297
298298function LinearAlgebra. norm (v:: AbstractGPUArray{T} , p:: Real = 2 ) where {T}
299- result_type = typeof ( float ( norm ( zero (T))) )
299+ result_type, sum_type, promote_ = _normtypes (T )
300300 isempty (v) && return zero (result_type)
301301 p == 0 && return convert (result_type, count (! iszero, v))
302- # Accumulate in at least Float32, like nrm2 in CUBLAS
303- acc_type = promote_type (Float32, result_type)
304- spp = convert (acc_type, p)
305- init = zero (acc_type) # To set the accumulation type in `sum`
306- # If acc_type is wider than T, widen before applying other functions. To work in GPU
307- # kernels this operation must close around a value, not a type, hence the prototype
308- prototype = zero (promote_type (T, acc_type))
309- widen (x) = convert (typeof (prototype), x)
302+ spp = convert (sum_type, p)
303+ init = zero (sum_type) # To set the accumulation type in `sum`
310304 # Rescaling heuristic similar to Base, see LinearAlgebra/src/generic.jl
311305 result = if p > 1 || p < - 1 # May need rescaling
312306 infnorm = p > 1 ? maximum (norm, v) : minimum (norm, v)
313307 if isinf (p) || iszero (infnorm) || isinf (infnorm)
314308 return convert (result_type, infnorm) # Return early to skip conversions
315309 end
316- factor = convert (acc_type , infnorm)
310+ factor = convert (sum_type , infnorm)
317311 if p == 2
318312 if isfinite (length (v) * factor^ 2 ) && ! iszero (factor^ 2 ) # No rescaling
319- sqrt (sum (x -> LinearAlgebra. norm_sqr (widen (x)), v; init= init))
313+ sqrt (sum (x -> LinearAlgebra. norm_sqr (promote_ (x)), v; init= init))
320314 else # Rescaling
321- factor * sqrt (sum (x -> (norm (widen (x)) / factor)^ 2 , v; init= init))
315+ factor * sqrt (sum (x -> (norm (promote_ (x)) / factor)^ 2 , v; init= init))
322316 end
323317 else
324318 if isfinite (length (v) * factor^ spp) && ! iszero (factor^ spp) # No rescaling
325- sum (x -> norm (widen (x))^ spp, v; init= init)^ inv (spp)
319+ sum (x -> norm (promote_ (x))^ spp, v; init= init)^ inv (spp)
326320 else # Rescaling
327- factor * (sum (x -> (norm (widen (x)) / factor)^ spp, v; init= init)^ inv (spp))
321+ factor * (sum (x -> (norm (promote_ (x)) / factor)^ spp, v; init= init)^ inv (spp))
328322 end
329323 end
330324 elseif p == 1
331- sum (x -> norm (widen (x)), v; init= init)
325+ sum (x -> norm (promote_ (x)), v; init= init)
332326 else
333- sum (x -> norm (widen (x))^ spp, v; init= init)^ inv (spp)
327+ sum (x -> norm (promote_ (x))^ spp, v; init= init)^ inv (spp)
334328 end
335329 return convert (result_type, result)
336330end
337331
332+ function _normtypes (:: Type{T} ) where {T}
333+ result_type = typeof (float (norm (zero (T))))
334+ # Accumulate in at least Float32, like nrm2 in CUBLAS
335+ sum_type = promote_type (Float32, result_type)
336+ # If sum_type is wider than T, promote before applying other functions. To work in GPU
337+ # kernels this operation must close around a value, not a type, hence the prototype
338+ prototype = zero (promote_type (T, sum_type))
339+ promote_ (x) = convert (typeof (prototype), x)
340+ return result_type, sum_type, promote_
341+ end
342+
343+ # # opnorm
344+
345+ function LinearAlgebra. opnorm1 (A:: AnyGPUArray{T,2} ) where {T}
346+ result_type, sum_type, promote_ = _normtypes (T)
347+ result = maximum (sum (x -> norm (promote_ (x)), A; dims= 1 ); init= zero (sum_type))
348+ return convert (result_type, result)
349+ end
350+
351+ function LinearAlgebra. opnormInf (A:: AnyGPUArray{T,2} ) where {T}
352+ result_type, sum_type, promote_ = _normtypes (T)
353+ result = maximum (sum (x -> norm (promote_ (x)), A; dims= 2 ); init= zero (sum_type))
354+ return convert (result_type, result)
355+ end
338356
339357# # symmetric
340358
0 commit comments