Skip to content

Commit 17c7234

Browse files
christiangnrd and pxl-th committed
shfl_down intrinsics
Co-Authored-By: Anton Smirnov <tonysmn97@gmail.com>
1 parent d052112 commit 17c7234

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

src/intrinsics.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,32 @@ Declare memory that is local to a workgroup.
116116
"""
117117
localmemory(::Type{T}, dims) where {T} = localmemory(T, Val(dims))
118118

119+
"""
120+
shfl_down(val::T, offset::Integer)::T where T
121+
122+
Read `val` from the lane whose id is `offset` higher than the calling lane's id.
123+
124+
!!! note
125+
Backend implementations **must** implement:
126+
```
127+
@device_override shfl_down(val::T, offset::Integer)::T where T
128+
```
129+
As well as the on-device functionality.
130+
"""
131+
function shfl_down end
132+
"""
    shfl_down_types(::Backend)::Vector{DataType}

Return the vector of `DataType`s for which `backend` supports `shfl_down`.

The generic fallback defined here returns an empty vector.

!!! note
    Backend implementations **must** implement this function
    only if they support `shfl_down` for any types.
"""
function shfl_down_types(::Backend)
    return DataType[]
end
143+
144+
119145
"""
120146
barrier()
121147

test/intrinsics.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,55 @@ function test_intrinsics_kernel(results)
2020
return
2121
end
2222

# Do NOT use this kernel as an example for your code.
# It was written assuming a single workgroup of exactly 32 work-items and
# is only valid for that launch configuration.
function shfl_down_test_kernel(a, b)
    # The local id is used as the lane id; that only holds under the
    # assumption stated above.
    idx = KI.get_local_id().x

    # Stage the input through workgroup-local memory.
    temp = KI.localmemory(eltype(b), 32)
    temp[idx] = a[idx]

    KI.barrier()

    value = temp[idx]

    # Every lane has to execute the shuffles: a lane can only read `value`
    # from lanes that take part in the same call, so the reduction must not
    # be guarded by a per-lane branch.
    value = value + KI.shfl_down(value, 16)
    value = value + KI.shfl_down(value, 8)
    value = value + KI.shfl_down(value, 4)
    value = value + KI.shfl_down(value, 2)
    value = value + KI.shfl_down(value, 1)

    # After the last step the first lane holds the sum of all 32 inputs;
    # only that lane writes the result.
    if idx == 1
        b[idx] = value
    end
    return
end
48+
# Variant of the shuffle-down reduction written directly against Metal.jl
# intrinsics. Like the kernel it mirrors, it assumes 32 threads that all
# belong to one simdgroup.
# NOTE(review): nothing in the visible test suite launches this kernel, and
# `Metal` is not loaded anywhere in view — confirm whether it should stay here.
function shfl_down_test_kernell(a, b)
    idx = Metal.thread_position_in_grid().x
    idx_in_simd = Metal.thread_index_in_simdgroup()

    # Stage the input through threadgroup memory.
    temp = Metal.MtlThreadGroupArray(eltype(b), 32)
    temp[idx] = a[idx]
    Metal.simdgroup_barrier(Metal.MemoryFlagThreadGroup)

    # NOTE(review): written at `idx`, read at `idx_in_simd`; these presumably
    # coincide for a single 32-thread simdgroup — confirm.
    value = temp[idx_in_simd]

    # Every lane has to execute the shuffles: a lane can only read `value`
    # from lanes that take part in the same call, so the reduction must not
    # be guarded by a per-lane branch.
    value = value + Metal.simd_shuffle_down(value, 16)
    value = value + Metal.simd_shuffle_down(value, 8)
    value = value + Metal.simd_shuffle_down(value, 4)
    value = value + Metal.simd_shuffle_down(value, 2)
    value = value + Metal.simd_shuffle_down(value, 1)

    # Only the first thread stores the reduced value.
    if idx == 1
        b[idx] = value
    end
    return
end
71+
2372
function intrinsics_testsuite(backend, AT)
2473
@testset "KernelIntrinsics Tests" begin
2574
@testset "Launch parameters" begin
@@ -119,6 +168,18 @@ function intrinsics_testsuite(backend, AT)
119168
@test local_id_x == expected_local
120169
end
121170
end
@testset "shfl_down(::$T)" for T in KI.shfl_down_types(backend())
    # Host input: 32 values drawn from 1:4, one per work-item of the
    # single 32-wide workgroup launched below.
    a = zeros(T, 32)
    rand!(a, (1:4))

    dev_a = AT(a)
    dev_b = AT(zeros(T, 32))

    KI.@kernel backend() workgroupsize=32 shfl_down_test_kernel(dev_a, dev_b)

    # The kernel writes the reduction of all inputs to the first output
    # element; compare it against the host-side sum.
    b = Array(dev_b)
    @test sum(a) ≈ b[1]
end
122183
end
123184
return nothing
124185
end

0 commit comments

Comments
 (0)