Skip to content

Commit fc0d327

Browse files
authored
Merge pull request #403 from danielwe/opnorm
Implement opnorm
2 parents 94a2e60 + 7899dcf commit fc0d327

File tree

2 files changed

+49
-16
lines changed

src/host/linalg.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -296,45 +296,63 @@ end
296296
## norm
297297

298298
# Compute the `p`-norm of the GPU array `v` (default `p = 2`).
#
# The result is converted to the real floating-point `result_type` reported by
# `_normtypes(T)`. Sums are accumulated in `sum_type`, and every element is passed through
# `promote_` before any other function is applied to it.
#
# Special cases handled up front:
# - an empty array yields `zero(result_type)`;
# - `p == 0` yields the number of nonzero elements.
function LinearAlgebra.norm(v::AbstractGPUArray{T}, p::Real=2) where {T}
    result_type, sum_type, promote_ = _normtypes(T)
    isempty(v) && return zero(result_type)
    p == 0 && return convert(result_type, count(!iszero, v))
    spp = convert(sum_type, p)
    init = zero(sum_type) # To set the accumulation type in `sum`
    # Rescaling heuristic similar to Base, see LinearAlgebra/src/generic.jl
    result = if p > 1 || p < -1 # May need rescaling
        # Largest (p > 1) or smallest (p < -1) element norm; doubles as the scale factor
        infnorm = p > 1 ? maximum(norm, v) : minimum(norm, v)
        if isinf(p) || iszero(infnorm) || isinf(infnorm)
            return convert(result_type, infnorm) # Return early to skip conversions
        end
        factor = convert(sum_type, infnorm)
        if p == 2
            if isfinite(length(v) * factor^2) && !iszero(factor^2) # No rescaling
                sqrt(sum(x -> LinearAlgebra.norm_sqr(promote_(x)), v; init=init))
            else # Rescaling
                factor * sqrt(sum(x -> (norm(promote_(x)) / factor)^2, v; init=init))
            end
        else
            if isfinite(length(v) * factor^spp) && !iszero(factor^spp) # No rescaling
                sum(x -> norm(promote_(x))^spp, v; init=init)^inv(spp)
            else # Rescaling
                factor * (sum(x -> (norm(promote_(x)) / factor)^spp, v; init=init)^inv(spp))
            end
        end
    elseif p == 1
        sum(x -> norm(promote_(x)), v; init=init)
    else
        # Remaining exponents (nonzero, not 1, and within [-1, 1]) are summed unscaled
        sum(x -> norm(promote_(x))^spp, v; init=init)^inv(spp)
    end
    return convert(result_type, result)
end
337331

332+
# Determine the types used when computing norms of arrays with element type `T`.
#
# Returns a 3-tuple `(result_type, sum_type, promote_)`:
# - `result_type`: the floating-point type of `norm(zero(T))`, used for the final result;
# - `sum_type`: the accumulation type, `promote_type(Float32, result_type)`;
# - `promote_`: a one-argument function converting an element to
#   `promote_type(T, sum_type)`.
function _normtypes(::Type{T}) where {T}
    result_type = typeof(float(norm(zero(T))))
    # Accumulate in at least Float32, like nrm2 in CUBLAS
    sum_type = promote_type(Float32, result_type)
    # If sum_type is wider than T, promote before applying other functions. To work in GPU
    # kernels this operation must close around a value, not a type, hence the prototype
    prototype = zero(promote_type(T, sum_type))
    promote_(x) = convert(typeof(prototype), x)
    return result_type, sum_type, promote_
end
342+
343+
## opnorm
344+
345+
# Operator 1-norm of a GPU matrix: the largest column sum of element norms.
#
# Elements are passed through `promote_` (from `_normtypes`) before `norm` is applied.
# The result is converted to `result_type`.
function LinearAlgebra.opnorm1(A::AnyGPUArray{T,2}) where {T}
    result_type, sum_type, promote_ = _normtypes(T)
    # Reducing over dims=1 leaves one sum per column
    column_sums = sum(x -> norm(promote_(x)), A; dims=1)
    # `init` supplies the starting value, so reducing over no columns yields zero
    result = maximum(column_sums; init=zero(sum_type))
    return convert(result_type, result)
end
350+
351+
# Operator Inf-norm of a GPU matrix: the largest row sum of element norms.
#
# Elements are passed through `promote_` (from `_normtypes`) before `norm` is applied.
# The result is converted to `result_type`.
function LinearAlgebra.opnormInf(A::AnyGPUArray{T,2}) where {T}
    result_type, sum_type, promote_ = _normtypes(T)
    # Reducing over dims=2 leaves one sum per row
    row_sums = sum(x -> norm(promote_(x)), A; dims=2)
    # `init` supplies the starting value for the maximum reduction
    result = maximum(row_sums; init=zero(sum_type))
    return convert(result_type, result)
end
338356

339357
## symmetric
340358

test/testsuite/linalg.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,4 +255,19 @@ end
255255
@test compare(norm, AT, arr, Ref(p))
256256
end
257257
end
258+
# Compare `opnorm` on the array type under test against the reference result for p = 1 and
# p = Inf, on an empty (2 x 0) and a non-empty (2 x 3) matrix, for every element type.
@testset "$p-opnorm($sz x $T)" for sz in [(2, 0), (2, 3)],
                                   p in Any[1, Inf],
                                   T in eltypes
    # NOTE(review): the reason Int8 is excluded is not stated here; confirm before changing
    if T == Int8
        continue
    end
    if !in(float(real(T)), eltypes)
        # norm promotes to float, so make sure that type is supported
        continue
    end
    # Draw integer-valued entries from 1:10 to prevent integer overflow
    domain = real(T) <: Integer ? (T.(1:10)) : T
    mat = rand(domain, sz)
    @test compare(opnorm, AT, mat, Ref(p))
    @test isrealfloattype(typeof(opnorm(AT(mat), p)))
end
258273
end

0 commit comments

Comments
 (0)