11# GPUArrays.jl interface
22
3+ import KernelAbstractions
4+ import KernelAbstractions: Backend
35
46#
57# Device functionality
810
911# # execution
1012
11- struct oneArrayBackend <: AbstractGPUBackend end
12-
13- struct oneKernelContext <: AbstractKernelContext end
13+ struct oneArrayBackend <: Backend end
1414
1515@inline function GPUArrays. launch_heuristic (:: oneArrayBackend , f:: F , args:: Vararg{Any,N} ;
1616 elements:: Int , elements_per_thread:: Int ) where {F,N}
@@ -25,48 +25,6 @@ struct oneKernelContext <: AbstractKernelContext end
2525 return (threads= items, blocks= 32 )
2626end
2727
28- function GPUArrays. gpu_call (:: oneArrayBackend , f, args, threads:: Int , blocks:: Int ;
29- name:: Union{String,Nothing} )
30- @oneapi items= threads groups= blocks name= name f (oneKernelContext (), args... )
31- end
32-
33-
34- # # on-device
35-
36- # indexing
37-
38- GPUArrays. blockidx (ctx:: oneKernelContext ) = oneAPI. get_group_id (0 )
39- GPUArrays. blockdim (ctx:: oneKernelContext ) = oneAPI. get_local_size (0 )
40- GPUArrays. threadidx (ctx:: oneKernelContext ) = oneAPI. get_local_id (0 )
41- GPUArrays. griddim (ctx:: oneKernelContext ) = oneAPI. get_num_groups (0 )
42-
43- # math
44-
45- @inline GPUArrays. cos (ctx:: oneKernelContext , x) = oneAPI. cos (x)
46- @inline GPUArrays. sin (ctx:: oneKernelContext , x) = oneAPI. sin (x)
47- @inline GPUArrays. sqrt (ctx:: oneKernelContext , x) = oneAPI. sqrt (x)
48- @inline GPUArrays. log (ctx:: oneKernelContext , x) = oneAPI. log (x)
49-
50- # memory
51-
52- @inline function GPUArrays. LocalMemory (:: oneKernelContext , :: Type{T} , :: Val{dims} , :: Val{id}
53- ) where {T, dims, id}
54- ptr = oneAPI. emit_localmemory (Val (id), T, Val (prod (dims)))
55- oneDeviceArray (dims, LLVMPtr {T, onePI.AS.Local} (ptr))
56- end
57-
58- # synchronization
59-
60- @inline GPUArrays. synchronize_threads (:: oneKernelContext ) = oneAPI. barrier ()
61-
62-
63-
64- #
65- # Host abstractions
66- #
67-
68- GPUArrays. backend (:: Type{<:oneArray} ) = oneArrayBackend ()
69-
7028const GLOBAL_RNGs = Dict {ZeDevice,GPUArrays.RNG} ()
7129function GPUArrays. default_rng (:: Type{<:oneArray} )
7230 dev = device ()
0 commit comments