Skip to content

Commit 855cc97

Browse files
authored
Merge pull request #410 from JeffFessler/master
Fix `isequal` and `==`: add size check, support missing values.
1 parent fc0d327 commit 855cc97

File tree

3 files changed

+53
-6
lines changed

3 files changed

+53
-6
lines changed

src/host/mapreduce.jl

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted) = mapreduced
1212

1313
neutral_element(op, T) =
1414
error("""GPUArrays.jl needs to know the neutral element for your operator `$op`.
15-
Please pass it as an explicit argument to (if possible), or register it
16-
globally your operator by defining `GPUArrays.neutral_element(::typeof($op), T)`.""")
15+
Please pass it as an explicit argument to `GPUArrays.mapreducedim!`,
16+
or register it globally by defining `GPUArrays.neutral_element(::typeof($op), T)`.""")
1717
neutral_element(::typeof(Base.:(|)), T) = zero(T)
1818
neutral_element(::typeof(Base.:(+)), T) = zero(T)
1919
neutral_element(::typeof(Base.add_sum), T) = zero(T)
@@ -64,7 +64,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
6464
R = similar(bc, ET, red)
6565

6666
if prod(sz) == 0
67-
R .= init
67+
fill!(R, init)
6868
else
6969
mapreducedim!(identity, op, R, bc; init=init)
7070
end
@@ -85,9 +85,6 @@ Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
8585
Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0) =
8686
mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
8787

88-
Base.:(==)(A::AnyGPUArray, B::AnyGPUArray) = Bool(mapreduce(==, &, A, B))
89-
Base.isequal(A::AnyGPUArray, B::AnyGPUArray) = mapreduce(isequal, &, A, B)
90-
9188
# avoid calling into `initarray!`
9289
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
9390
(:maximum, :(Base.max)), (:minimum, :(Base.min)),
@@ -100,3 +97,40 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
10097
end
10198

10299
LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A))
100+
101+
102+
# comparisons
103+
104+
# ignores missing
105+
function Base.isequal(A::AnyGPUArray, B::AnyGPUArray)
106+
if A === B return true end
107+
if axes(A) != axes(B)
108+
return false
109+
end
110+
mapreduce(isequal, &, A, B; init=true)
111+
end
112+
113+
# returns `missing` when missing values are involved
114+
function Base.:(==)(A::AnyGPUArray, B::AnyGPUArray)
115+
if axes(A) != axes(B)
116+
return false
117+
end
118+
119+
function mapper(a, b)
120+
eq = (a == b)
121+
if ismissing(eq)
122+
(; is_missing=true, is_equal=#=don't care=#false)
123+
else
124+
(; is_missing=false, is_equal=eq)
125+
end
126+
end
127+
function reducer(a, b)
128+
if a.is_missing || b.is_missing
129+
(; is_missing=true, is_equal=#=don't care=#false)
130+
else
131+
(; is_missing=false, is_equal=a.is_equal & b.is_equal)
132+
end
133+
end
134+
res = mapreduce(mapper, reducer, A, B; init=(; is_missing=false, is_equal=true))
135+
res.is_missing ? missing : res.is_equal
136+
end

test/testsuite.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ struct ArrayAdaptor{AT} end
1818
Adapt.adapt_storage(::ArrayAdaptor{AT}, xs::AbstractArray) where {AT} = AT(xs)
1919

2020
test_result(a::Number, b::Number; kwargs...) = (a, b; kwargs...)
21+
test_result(a::Missing, b::Missing; kwargs...) = true
22+
test_result(a::Number, b::Missing; kwargs...) = false
23+
test_result(a::Missing, b::Number; kwargs...) = false
2124
function test_result(a::AbstractArray{T}, b::AbstractArray{T}; kwargs...) where {T<:Number}
2225
(collect(a), collect(b); kwargs...)
2326
end

test/testsuite/reductions.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,12 @@ end
159159
@testsuite "reductions/== isequal" (AT, eltypes)->begin
160160
@testset "$ET" for ET in eltypes
161161
range = ET <: Real ? (ET(1):ET(10)) : ET
162+
163+
# different sizes should trip up both (CUDA.jl#1524)
164+
@test compare((A, B) -> A == B, AT, rand(range, (2,3)), rand(range, 6))
165+
@test compare((A, B) -> isequal(A, B), AT, rand(range, (2,3)), rand(range, 6))
166+
167+
# equal sizes depend on values
162168
for sz in [(10,), (10,10), (10,10,10), (0,)]
163169
@test compare((A, B) -> A == B, AT, rand(range, sz), rand(range, sz))
164170
@test compare((A, B) -> isequal(A, B), AT, rand(range, sz), rand(range, sz))
@@ -179,5 +185,9 @@ end
179185
@test compare((A, B) -> isequal(A, B), AT, Ac, Bc)
180186
end
181187
end
188+
189+
# missing values should only trip up ==
190+
@test compare((A, B) -> A == B, AT, [missing], [missing])
191+
@test compare((A, B) -> isequal(A, B), AT, [missing], [missing])
182192
end
183193
end

0 commit comments

Comments
 (0)