@@ -20,6 +20,55 @@ function test_intrinsics_kernel(results)
2020 return
2121end
2222
# Do NOT use this kernel as an example for your code.
# It was written assuming a single workgroup of size 32 and
# is only valid for that configuration.
#
# Reduces the 32 elements of `a` with a shuffle-down tree (offsets 16, 8, 4,
# 2, 1) and has the work-item whose local id is 1 store the result in `b[1]`.
#
# Arguments:
#   a - device array holding the 32 input values
#   b - device array receiving the result; only `b[1]` is written
function shfl_down_test_kernel(a, b)
    # NOTE(review): the original author marked this line "not valid" --
    # presumably because the workgroup-local id is only usable as a lane
    # index under the one-workgroup-of-32 assumption above. Confirm.
    idx = KI.get_local_id().x

    # Stage the inputs through workgroup-local memory, as the original did.
    temp = KI.localmemory(eltype(b), 32)
    temp[idx] = a[idx]

    KI.barrier()

    value = temp[idx]

    # The shuffles must be executed by EVERY work-item, not just the one that
    # stores the result: each step reads the partial sum held by a
    # higher-numbered lane, so those lanes have to take part in the shuffle
    # and must have run the earlier steps themselves. Guarding the shuffles
    # with `idx == 1` would leave the source lanes idle and the value read
    # from them unspecified.
    value = value + KI.shfl_down(value, 16)
    value = value + KI.shfl_down(value, 8)
    value = value + KI.shfl_down(value, 4)
    value = value + KI.shfl_down(value, 2)
    value = value + KI.shfl_down(value, 1)

    # Only the first work-item holds the complete reduction.
    if idx == 1
        b[idx] = value
    end
    return
end
48+
# Metal-specific counterpart of the shuffle-down reduction test kernel: it
# reduces the staged values with `Metal.simd_shuffle_down` (offsets 16, 8, 4,
# 2, 1) and has the thread at grid position 1 store the result in `b[1]`.
#
# Like its generic sibling, this is written for a single threadgroup of
# 32 threads and is not meant as example code.
#
# NOTE(review): the name ends in a double "l" and nothing in the visible part
# of this file calls it -- it looks like a backend-specific debugging variant.
# Confirm whether it should be renamed or removed.
#
# Arguments:
#   a - device array holding the 32 input values
#   b - device array receiving the result; only `b[1]` is written
function shfl_down_test_kernell(a, b)
    idx = Metal.thread_position_in_grid().x
    idx_in_simd = Metal.thread_index_in_simdgroup()

    # Stage the inputs through threadgroup memory, as the original did.
    temp = Metal.MtlThreadGroupArray(eltype(b), 32)
    temp[idx] = a[idx]
    Metal.simdgroup_barrier(Metal.MemoryFlagThreadGroup)

    # Written at the grid position, read back at the SIMD-group lane index;
    # presumably the two coincide only for a single threadgroup of 32 -- confirm.
    value = temp[idx_in_simd]

    # All threads of the SIMD-group must execute the shuffles: every step
    # reads the partial sum of a higher-numbered lane, which therefore has to
    # participate and must have run the previous steps itself. Performing the
    # shuffles only on the thread that writes the result would read from
    # lanes that never executed them.
    value = value + Metal.simd_shuffle_down(value, 16)
    value = value + Metal.simd_shuffle_down(value, 8)
    value = value + Metal.simd_shuffle_down(value, 4)
    value = value + Metal.simd_shuffle_down(value, 2)
    value = value + Metal.simd_shuffle_down(value, 1)

    # Only the first thread holds the complete reduction.
    if idx == 1
        b[idx] = value
    end
    return
end
71+
2372function intrinsics_testsuite (backend, AT)
2473 @testset " KernelIntrinsics Tests" begin
2574 @testset " Launch parameters" begin
@@ -119,6 +168,18 @@ function intrinsics_testsuite(backend, AT)
119168 @test local_id_x == expected_local
120169 end
121170 end
171+ @testset " shfl_down(::$T )" for T in KI. shfl_down_types (backend ())
172+ a = zeros (T, 32 )
173+ rand! (a, (1 : 4 ))
174+
175+ dev_a = AT (a)
176+ dev_b = AT (zeros (T, 32 ))
177+
178+ KI. @kernel backend () workgroupsize= 32 shfl_down_test_kernel (dev_a, dev_b)
179+
180+ b = Array (dev_b)
181+ @test sum (a) ≈ b[1 ]
182+ end
122183 end
123184 return nothing
124185end
0 commit comments