Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Artifacts
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS

include("version.jl")
import KernelAbstractions as KA

# core library
include("../lib/mtl/MTL.jl")
Expand Down Expand Up @@ -60,13 +61,13 @@ include("mapreduce.jl")
include("accumulate.jl")
include("indexing.jl")
include("random.jl")
include("gpuarrays.jl")

# KernelAbstractions
include("MetalKernels.jl")
import .MetalKernels: MetalBackend
import .MetalKernels: MetalBackend, KA.launch_config
export MetalBackend

include("gpuarrays.jl")
include("deprecated.jl")

include("precompile.jl")
Expand Down
55 changes: 12 additions & 43 deletions src/gpuarrays.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
## GPUArrays interfaces

## execution
# Return the value stored in the `dev` field of the given `MtlArray`.
function GPUArrays.device(x::MtlArray)
    return x.dev
end

# Field-less marker type, subtype of `AbstractGPUBackend`; it exists only so
# that methods can dispatch on it.
struct mtlArrayBackend <: AbstractGPUBackend
end
import KernelAbstractions
import KernelAbstractions: Backend

# Field-less marker type, subtype of `AbstractKernelContext`; an instance is
# handed to kernels as their first argument and used purely for dispatch.
struct mtlKernelContext <: AbstractKernelContext
end
@inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
elements::Int, elements_per_thread::Int) where {O,N}

@inline function GPUArrays.launch_heuristic(::mtlArrayBackend, f::F, args::Vararg{Any,N};
elements::Int, elements_per_thread::Int) where {F,N}
kernel = @metal launch=false f(mtlKernelContext(), args...)
ndrange = ceil(Int, elements / elements_per_thread)
ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange,
nothing)

ctx = KA.mkcontext(obj, ndrange, iterspace)

kernel = @metal launch=false obj.f(ctx, args...)

# The pipeline state automatically computes occupancy stats
threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)
Expand All @@ -17,43 +23,6 @@ struct mtlKernelContext <: AbstractKernelContext end
return (; threads=Int(threads), blocks=Int(blocks))
end

# Kernel-launch method for `mtlArrayBackend`.
#
# Invokes `f` through the `@metal` macro, handing it a freshly constructed
# `mtlKernelContext()` as the first argument followed by the splatted `args`.
# `threads`, `groups` and `name` are passed on to `@metal`; `name` is a
# required keyword (no default) and may be `nothing`.
# NOTE(review): the bare `threads groups name` symbols are presumably the
# macro's shorthand for `threads=threads` etc. — confirm against `@metal`.
function GPUArrays.gpu_call(::mtlArrayBackend, f, args, threads::Int, groups::Int;
name::Union{String,Nothing})
@metal threads groups name f(mtlKernelContext(), args...)
end


## on-device

# indexing
# Each query below forwards to the 1-D intrinsic named on its right-hand side.
# The context argument is used only for dispatch, so it is left unnamed.
GPUArrays.blockidx(::mtlKernelContext)     = threadgroup_position_in_grid_1d()
GPUArrays.blockdim(::mtlKernelContext)     = threads_per_threadgroup_1d()
GPUArrays.threadidx(::mtlKernelContext)    = thread_position_in_threadgroup_1d()
GPUArrays.griddim(::mtlKernelContext)      = threadgroups_per_grid_1d()
GPUArrays.global_index(::mtlKernelContext) = thread_position_in_grid_1d()
GPUArrays.global_size(::mtlKernelContext)  = threads_per_grid_1d()

# memory

# Build an `MtlDeviceArray` of shape `dims` on top of the pointer returned by
# `emit_threadgroup_memory` for `prod(dims)` elements of type `T`.
# The `id` type parameter is accepted but not used in this method body.
@inline function GPUArrays.LocalMemory(
    ::mtlKernelContext, ::Type{T}, ::Val{dims}, ::Val{id}
) where {T, dims, id}
    nelements = prod(dims)
    buffer = emit_threadgroup_memory(T, Val(nelements))
    return MtlDeviceArray(dims, buffer)
end

# synchronization

# Synchronization hook for `mtlKernelContext`: calls `threadgroup_barrier`
# with the `MemoryFlagThreadGroup` flag and returns its result.
@inline function GPUArrays.synchronize_threads(::mtlKernelContext)
    return threadgroup_barrier(MemoryFlagThreadGroup)
end



#
# Host abstractions
#

# Map any `MtlArray` type (dispatching on the type, not an instance) to a
# fresh `mtlArrayBackend` value.
function GPUArrays.backend(::Type{<:MtlArray})
    return mtlArrayBackend()
end

const GLOBAL_RNGs = Dict{MTLDevice,GPUArrays.RNG}()
function GPUArrays.default_rng(::Type{<:MtlArray})
dev = device()
Expand Down