@@ -12,8 +12,8 @@ Base.mapreducedim!(f, op, R::AnyGPUArray, A::Broadcast.Broadcasted) = mapreduced
1212
1313neutral_element (op, T) =
1414 error (""" GPUArrays.jl needs to know the neutral element for your operator `$op `.
15- Please pass it as an explicit argument to (if possible), or register it
16- globally your operator by defining `GPUArrays.neutral_element(::typeof($op ), T)`.""" )
15+ Please pass it as an explicit argument to `GPUArrays.mapreducedim!`,
16+ or register it globally by defining `GPUArrays.neutral_element(::typeof($op ), T)`.""" )
1717neutral_element (:: typeof (Base.:(| )), T) = zero (T)
1818neutral_element (:: typeof (Base.:(+ )), T) = zero (T)
1919neutral_element (:: typeof (Base. add_sum), T) = zero (T)
@@ -64,7 +64,7 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
6464 R = similar (bc, ET, red)
6565
6666 if prod (sz) == 0
67- R . = init
67+ fill! (R, init)
6868 else
6969 mapreducedim! (identity, op, R, bc; init= init)
7070 end
@@ -85,9 +85,6 @@ Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
8585Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
8686 mapreduce (pred, Base. add_sum, A; init= init, dims= dims)
8787
88- Base.:(== )(A:: AnyGPUArray , B:: AnyGPUArray ) = Bool (mapreduce (== , & , A, B))
89- Base. isequal (A:: AnyGPUArray , B:: AnyGPUArray ) = mapreduce (isequal, & , A, B)
90-
9188# avoid calling into `initarray!`
9289for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
9390 (:maximum , :(Base. max)), (:minimum , :(Base. min)),
@@ -100,3 +97,40 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
10097end
10198
10299LinearAlgebra. ishermitian (A:: AbstractGPUMatrix ) = mapreduce (== , & , A, adjoint (A))
100+
101+
102+ # comparisons
103+
104+ # ignores missing
105+ function Base. isequal (A:: AnyGPUArray , B:: AnyGPUArray )
106+ if A === B return true end
107+ if axes (A) != axes (B)
108+ return false
109+ end
110+ mapreduce (isequal, & , A, B; init= true )
111+ end
112+
113+ # returns `missing` when missing values are involved
114+ function Base.:(== )(A:: AnyGPUArray , B:: AnyGPUArray )
115+ if axes (A) != axes (B)
116+ return false
117+ end
118+
119+ function mapper (a, b)
120+ eq = (a == b)
121+ if ismissing (eq)
122+ (; is_missing= true , is_equal= #= don't care=# false )
123+ else
124+ (; is_missing= false , is_equal= eq)
125+ end
126+ end
127+ function reducer (a, b)
128+ if a. is_missing || b. is_missing
129+ (; is_missing= true , is_equal= #= don't care=# false )
130+ else
131+ (; is_missing= false , is_equal= a. is_equal & b. is_equal)
132+ end
133+ end
134+ res = mapreduce (mapper, reducer, A, B; init= (; is_missing= false , is_equal= true ))
135+ res. is_missing ? missing : res. is_equal
136+ end
0 commit comments