@@ -28,8 +28,14 @@ neutral_element(::typeof(Base._extrema_rf), ::Type{<:NTuple{2,T}}) where {T} = t
2828# resolve ambiguities
2929Base. mapreduce (f, op, A:: AnyGPUArray , As:: AbstractArrayOrBroadcasted... ;
3030 dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
31+ # dims=:, init=nothing) = AK._mapreduce(f, op, A, As...; dims=dims, init=init)
3132Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} , As:: AbstractArrayOrBroadcasted... ;
3233 dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
34+ # dims=:, init=nothing) = AK.mapreduce(f, op, #_mapreduce(f, op, A, As...; dims=dims, init=init)
35+ Base. mapreduce (f, op, A:: AnyGPUArray ;
36+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
37+ Base. mapreduce (f, op, A:: Broadcast.Broadcasted{<:AbstractGPUArrayStyle} ;
38+ dims= :, init= nothing ) = AK. mapreduce (f, op, A; init, dims= dims isa Colon ? nothing : dims)
3339
3440function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
3541 # figure out the destination container type by looking at the initializer element,
@@ -85,14 +91,14 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
8591 end
8692end
8793
88- Base. any (A:: AnyGPUArray{Bool} ) = mapreduce (identity, | , A)
89- Base. all (A:: AnyGPUArray{Bool} ) = mapreduce (identity, & , A)
94+ Base. any (A:: AnyGPUArray{Bool} ) = AK . any (identity, A)
95+ Base. all (A:: AnyGPUArray{Bool} ) = AK . all (identity, A)
9096
91- Base. any (f:: Function , A:: AnyGPUArray ) = mapreduce (f, | , A)
92- Base. all (f:: Function , A:: AnyGPUArray ) = mapreduce (f, & , A)
97+ Base. any (f:: Function , A:: AnyGPUArray ) = AK . any (f , A)
98+ Base. all (f:: Function , A:: AnyGPUArray ) = AK . all (f , A)
9399
94100Base. count (pred:: Function , A:: AnyGPUArray ; dims= :, init= 0 ) =
95- mapreduce (pred, Base. add_sum, A; init= init, dims= dims)
101+ AK . count (pred, A; init, dims = dims isa Colon ? nothing : dims) # mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
96102
97103# avoid calling into `initarray!`
98104for (fname, op) in [(:sum , :(Base. add_sum)), (:prod , :(Base. mul_prod)),
0 commit comments