diff --git a/src/Metal.jl b/src/Metal.jl
index 08eba6039..9cef4709c 100644
--- a/src/Metal.jl
+++ b/src/Metal.jl
@@ -14,6 +14,7 @@ using Artifacts
 using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
 
 include("version.jl")
+import KernelAbstractions as KA
 
 # core library
 include("../lib/mtl/MTL.jl")
@@ -60,13 +61,13 @@
 include("mapreduce.jl")
 include("accumulate.jl")
 include("indexing.jl")
 include("random.jl")
-include("gpuarrays.jl")
 
 # KernelAbstractions
 include("MetalKernels.jl")
-import .MetalKernels: MetalBackend
+import .MetalKernels: MetalBackend, KA.launch_config
 export MetalBackend
 
+include("gpuarrays.jl")
 include("deprecated.jl")
 include("precompile.jl")
diff --git a/src/gpuarrays.jl b/src/gpuarrays.jl
index d8aaae548..84725512b 100644
--- a/src/gpuarrays.jl
+++ b/src/gpuarrays.jl
@@ -1,14 +1,20 @@
 ## GPUArrays interfaces
 
-## execution
+GPUArrays.device(x::MtlArray) = x.dev
 
-struct mtlArrayBackend <: AbstractGPUBackend end
+import KernelAbstractions
+import KernelAbstractions: Backend
 
-struct mtlKernelContext <: AbstractKernelContext end
+@inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
+                                            elements::Int, elements_per_thread::Int) where {O,N}
 
-@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
-                                            elements::Int, elements_per_thread::Int) where {F,N}
-    kernel = @metal launch=false f(mtlKernelContext(), args...)
+    ndrange = ceil(Int, elements / elements_per_thread)
+    ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange,
+                                                                  nothing)
+
+    ctx = KA.mkcontext(obj, ndrange, iterspace)
+
+    kernel = @metal launch=false obj.f(ctx, args...)
 
     # The pipeline state automatically computes occupancy stats
     threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
@@ -17,43 +23,6 @@ struct mtlKernelContext <: AbstractKernelContext end
     return (; threads=Int(threads), blocks=Int(blocks))
 end
 
-function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
-                            name::Union{String,Nothing})
-    @metal threads groups name f(mtlKernelContext(), args...)
-end
-
-
-## on-device
-
-# indexing
-GPUArrays.blockidx(ctx::mtlKernelContext) = threadgroup_position_in_grid_1d()
-GPUArrays.blockdim(ctx::mtlKernelContext) = threads_per_threadgroup_1d()
-GPUArrays.threadidx(ctx::mtlKernelContext) = thread_position_in_threadgroup_1d()
-GPUArrays.griddim(ctx::mtlKernelContext) = threadgroups_per_grid_1d()
-GPUArrays.global_index(ctx::mtlKernelContext) = thread_position_in_grid_1d()
-GPUArrays.global_size(ctx::mtlKernelContext) = threads_per_grid_1d()
-
-# memory
-
-@inline function GPUArrays.LocalMemory(::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
-                                       ) where {T, dims, id}
-    ptr = emit_threadgroup_memory(T, Val(prod(dims)))
-    MtlDeviceArray(dims, ptr)
-end
-
-# synchronization
-
-@inline GPUArrays.synchronize_threads(::mtlKernelContext) =
-    threadgroup_barrier(MemoryFlagThreadGroup)
-
-
-
-#
-# Host abstractions
-#
-
-GPUArrays.backend(::Type{<:MtlArray}) = mtlArrayBackend()
-
 const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
 function GPUArrays.default_rng(::Type{<:MtlArray})
     dev = device()