@@ -117,28 +117,14 @@ function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
117117 end
118118
119119 # grid-stride kernel
120- @kernel function map_kernel (dest, bc, nelem, common_length)
121-
122- j = 0
123- J = @index (Global, Linear)
124- for i in 1 : nelem
125- j += 1
126- if j <= common_length
127-
128- J_c = CartesianIndices (axes (bc))[(J- 1 )* nelem + j]
129- @inbounds dest[J_c] = bc[J_c]
130- end
131- end
120+ @kernel function map_kernel (dest, bc)
121+ j = @index (Global, Linear)
122+ @inbounds dest[j] = bc[j]
132123 end
133- elements = common_length
134- elements_per_thread = typemax (Int)
124+
135125 kernel = map_kernel (get_backend (dest))
136- heuristic = launch_heuristic (get_backend (dest), kernel, dest, bc, 1 ,
137- common_length; elements, elements_per_thread)
138- config = launch_configuration (get_backend (dest), heuristic;
139- elements, elements_per_thread)
140- kernel (dest, bc, config. elements_per_thread,
141- common_length; ndrange = config. threads)
126+ config = KernelAbstractions. launch_config (kernel, common_length, nothing )
127+ kernel (dest, bc; ndrange = config[1 ], workgroupsize = config[2 ])
142128
143129 if eltype (dest) <: BrokenBroadcast
144130 throw (ArgumentError (" Map operation resulting in $(eltype (eltype (dest))) is not GPU compatible" ))
0 commit comments