Skip to content

Commit 6482f3d

Browse files
committed
Remove undesired promotions with Float64 literals
1 parent f26e4ee commit 6482f3d

File tree

4 files changed

+499
-290
lines changed

4 files changed

+499
-290
lines changed

src/deviation.jl

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,14 @@ end
4646
Compute the squared L2 distance between two arrays: ``\\sum_{i=1}^n |a_i - b_i|^2``.
4747
Efficient equivalent of `sum(abs2, a - b)`.
4848
"""
49-
function sqL2dist(a::AbstractArray{T}, b::AbstractArray{T}) where T<:Number
49+
function sqL2dist(a::AbstractArray{<:Number}, b::AbstractArray{<:Number})
5050
n = length(a)
5151
length(b) == n || throw(DimensionMismatch("Input dimension mismatch"))
52-
r = 0.0
53-
for i in eachindex(a, b)
54-
r += abs2(a[i] - b[i])
52+
if iszero(n)
53+
r = zero(abs2(zero(eltype(a)) - zero(eltype(b))))
54+
else
55+
broadcasted = Broadcast.broadcasted((ai, bi) -> abs2(ai - bi), vec(a), vec(b))
56+
r = sum(Broadcast.instantiate(broadcasted))
5557
end
5658
return r
5759
end
@@ -64,7 +66,7 @@ end
6466
Compute the L2 distance between two arrays: ``\\sqrt{\\sum_{i=1}^n |a_i - b_i|^2}``.
6567
Efficient equivalent of `sqrt(sum(abs2, a - b))`.
6668
"""
67-
L2dist(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} = sqrt(sqL2dist(a, b))
69+
L2dist(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = sqrt(sqL2dist(a, b))
6870

6971

7072
# L1 distance
@@ -74,12 +76,14 @@ L2dist(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} = sqrt(sqL2di
7476
Compute the L1 distance between two arrays: ``\\sum_{i=1}^n |a_i - b_i|``.
7577
Efficient equivalent of `sum(abs, a - b)`.
7678
"""
77-
function L1dist(a::AbstractArray{T}, b::AbstractArray{T}) where T<:Number
79+
function L1dist(a::AbstractArray{<:Number}, b::AbstractArray{<:Number})
7880
n = length(a)
7981
length(b) == n || throw(DimensionMismatch("Input dimension mismatch"))
80-
r = 0.0
81-
for i in eachindex(a, b)
82-
r += abs(a[i] - b[i])
82+
if iszero(n)
83+
r = zero(abs(zero(eltype(a)) - zero(eltype(b))))
84+
else
85+
broadcasted = Broadcast.broadcasted((ai, bi) -> abs(ai - bi), vec(a), vec(b))
86+
r = sum(Broadcast.instantiate(broadcasted))
8387
end
8488
return r
8589
end
@@ -93,15 +97,14 @@ Compute the L∞ distance, also called the Chebyshev distance, between
9397
two arrays: ``\\max_{1≤i≤n} |a_i - b_i|``.
9498
Efficient equivalent of `maxabs(a - b)`.
9599
"""
96-
function Linfdist(a::AbstractArray{T}, b::AbstractArray{T}) where T<:Number
100+
function Linfdist(a::AbstractArray{<:Number}, b::AbstractArray{<:Number})
97101
n = length(a)
98102
length(b) == n || throw(DimensionMismatch("Input dimension mismatch"))
99-
r = 0.0
100-
for i in eachindex(a, b)
101-
v = abs(a[i] - b[i])
102-
if r < v
103-
r = v
104-
end
103+
if iszero(n)
104+
r = zero(abs(zero(eltype(a)) - zero(eltype(b))))
105+
else
106+
broadcasted = Broadcast.broadcasted((ai, bi) -> abs(ai - bi), vec(a), vec(b))
107+
r = maximum(Broadcast.instantiate(broadcasted))
105108
end
106109
return r
107110
end
@@ -115,19 +118,20 @@ Compute the generalized Kullback-Leibler divergence between two arrays:
115118
``\\sum_{i=1}^n (a_i \\log(a_i/b_i) - a_i + b_i)``.
116119
Efficient equivalent of `sum(a*log(a/b)-a+b)`.
117120
"""
118-
function gkldiv(a::AbstractArray{T}, b::AbstractArray{T}) where T<:AbstractFloat
121+
function gkldiv(a::AbstractArray{<:Real}, b::AbstractArray{<:Real})
119122
n = length(a)
120-
r = 0.0
121-
for i in eachindex(a, b)
122-
ai = a[i]
123-
bi = b[i]
124-
if ai > 0
125-
r += (ai * log(ai / bi) - ai + bi)
126-
else
127-
r += bi
123+
length(b) == n || throw(DimensionMismatch("Input dimension mismatch"))
124+
if iszero(n)
125+
za = zero(eltype(a))
126+
zb = zero(eltype(b))
127+
r = zero(xlogy(za, za / zb) + (zb - za))
128+
else
129+
broadcasted = Broadcast.broadcasted(vec(a), vec(b)) do ai, bi
130+
return xlogy(ai, ai / bi) + (bi - ai)
128131
end
132+
return sum(Broadcast.instantiate(broadcasted))
129133
end
130-
return r::Float64
134+
return r
131135
end
132136

133137

@@ -137,8 +141,7 @@ end
137141
138142
Return the mean absolute deviation between two arrays: `mean(abs, a - b)`.
139143
"""
140-
meanad(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} =
141-
L1dist(a, b) / length(a)
144+
meanad(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = L1dist(a, b) / length(a)
142145

143146

144147
# MaxAD: maximum absolute deviation
@@ -147,7 +150,7 @@ meanad(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} =
147150
148151
Return the maximum absolute deviation between two arrays: `maxabs(a - b)`.
149152
"""
150-
maxad(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} = Linfdist(a, b)
153+
maxad(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = Linfdist(a, b)
151154

152155

153156
# MSD: mean squared deviation
@@ -156,8 +159,7 @@ maxad(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} = Linfdist(a,
156159
157160
Return the mean squared deviation between two arrays: `mean(abs2, a - b)`.
158161
"""
159-
msd(a::AbstractArray{T}, b::AbstractArray{T}) where {T<:Number} =
160-
sqL2dist(a, b) / length(a)
162+
msd(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}) = sqL2dist(a, b) / length(a)
161163

162164

163165
# RMSD: root mean squared deviation
@@ -168,13 +170,14 @@ Return the root mean squared deviation between two optionally
168170
normalized arrays. The root mean squared deviation is computed
169171
as `sqrt(msd(a, b))`.
170172
"""
171-
function rmsd(a::AbstractArray{T}, b::AbstractArray{T}; normalize::Bool=false) where T<:Number
173+
function rmsd(a::AbstractArray{<:Number}, b::AbstractArray{<:Number}; normalize::Bool=false)
172174
v = sqrt(msd(a, b))
173175
if normalize
174176
amin, amax = extrema(a)
175-
v /= (amax - amin)
177+
return v / (amax - amin)
178+
else
179+
return v
176180
end
177-
return v
178181
end
179182

180183

@@ -186,6 +189,6 @@ Compute the peak signal-to-noise ratio between two arrays `a` and `b`.
186189
`maxv` is the maximum possible value either array can take. The PSNR
187190
is computed as `10 * log10(maxv^2 / msd(a, b))`.
188191
"""
189-
function psnr(a::AbstractArray{T}, b::AbstractArray{T}, maxv::Real) where T<:Real
190-
20. * log10(maxv) - 10. * log10(msd(a, b))
192+
function psnr(a::AbstractArray{<:Real}, b::AbstractArray{<:Real}, maxv::Real)
193+
20 * log10(maxv) - 10 * log10(msd(a, b))
191194
end

0 commit comments

Comments
 (0)