@@ -611,7 +611,7 @@ ClipNorm(ω = 10f0, p = 2; throw::Bool = true) = ClipNorm{float(typeof(ω))}(ω,
611611init (o:: ClipNorm , x:: AbstractArray ) = nothing
612612
613613function apply! (o:: ClipNorm , state, x, dx)
614- nrm = norm (dx, o. p)
614+ nrm = _norm (dx, o. p)
615615 if o. throw && ! isfinite (nrm)
616616 throw (DomainError (" gradient has $(o. p) -norm $nrm , for array $(summary (x)) " ))
617617 end
@@ -620,6 +620,48 @@ function apply!(o::ClipNorm, state, x, dx)
620620 return state, @lazy dx * λ
621621end
622622
623+ _norm (dx:: AbstractArray , p:: Real ) = norm (dx, p) # LinearAlgebra, CUDA
624+ function _norm (dx:: Broadcast.Broadcasted , p:: Real )
625+ if p == 2
626+ # This lacks the undeflow/overflow tests of LinearAlgebra's version
627+ sqrt (sum (abs2, dx))
628+ elseif p == 1
629+ float (sum (abs, dx))
630+ elseif p == Inf
631+ float (maximum (abs, dx))
632+ elseif p == 0
633+ cnt = count (! iszero, dx)
634+ T = Base. @default_eltype dx
635+ T <: Number ? convert (float (T), cnt) : cnt
636+ elseif p == - Inf
637+ float (minimum (abs, dx))
638+ else
639+ # This isn't optimally fast but does ensure p::Float64 doesn't promote
640+ tmp = abs .(dx)
641+ q = convert (float (eltype (tmp)), p)
642+ sum (tmp .^ q) ^ (1 / q)
643+ end
644+ end
645+
646+ #=
647+
648+ julia> using Metal
649+
650+ julia> using Base.Broadcast: broadcasted, instantiate
651+
652+ julia> bc = instantiate(broadcasted(+, MtlArray(rand(Float32, 3)), 1));
653+
654+ julia> norm(bc)
655+ ┌ Warning: Performing scalar indexing
656+
657+ └ @ Metal ~/.julia/packages/Metal/TtPHW/src/compiler/compilation.jl:77
658+ ERROR: NSError: Undefined symbols:
659+ llvm.maximum.f32, referenced from: _Z24partial_mapreduce_device8identity3max7Float323ValILi1024EES2_I22CartesianIndices__3___ES2_I22CartesianIndices__1___ES2_ILi1EES2_ILi1EES2_ILitrueEE14MtlDeviceArrayIS1_Li2ELi1EE11BroadcastedI13MtlArrayStyleILi1EE5TupleI5OneToI5Int64EE4normS6_IS4_IS5_ILi1EES6_IS7_IS8_EE1_S6_IS3_IS1_Li1ELi1EES8_EEEE
660+
661+ julia> Metal.allowscalar(false)
662+
663+ =#
664+
623665"""
624666 OptimiserChain(opts...)
625667
0 commit comments