1+ /*
2+ // Copyright (c) 2022-2024 Ben Ashbaugh
3+ //
4+ // SPDX-License-Identifier: MIT
5+ */
6+
7+ #include < CL/cl.h>
8+ #include < CL/cl_layer.h>
9+
10+ #include < algorithm>
11+ #include < array>
12+ #include < atomic>
13+ #include < string>
14+ #include < vector>
15+
16+ #include " layer_util.hpp"
17+
18+ #include " emulate.h"
19+ #include " subgroups.cl.h"
20+
21+ const cl_uint g_NVDeviceVendorID = 0x10DE ;
22+
23+ static inline bool isNV (cl_device_id device)
24+ {
25+ cl_uint deviceVendorID = 0 ;
26+ g_pNextDispatch->clGetDeviceInfo (
27+ device,
28+ CL_DEVICE_VENDOR_ID,
29+ sizeof (deviceVendorID),
30+ &deviceVendorID,
31+ nullptr );
32+ return deviceVendorID == g_NVDeviceVendorID;
33+ }
34+
35+ static inline bool isNV (cl_program program)
36+ {
37+ cl_uint numDevices = 0 ;
38+ g_pNextDispatch->clGetProgramInfo (
39+ program,
40+ CL_PROGRAM_NUM_DEVICES,
41+ sizeof (numDevices),
42+ &numDevices,
43+ nullptr );
44+
45+ std::vector<cl_device_id> devices (numDevices);
46+ g_pNextDispatch->clGetProgramInfo (
47+ program,
48+ CL_PROGRAM_DEVICES,
49+ numDevices * sizeof (cl_device_id),
50+ devices.data (),
51+ nullptr );
52+
53+ return std::all_of (
54+ devices.begin (),
55+ devices.end (),
56+ [](cl_device_id device) { return isNV (device); });
57+ }
58+
59+ static inline bool isNV (cl_context context)
60+ {
61+ cl_uint numDevices = 0 ;
62+ g_pNextDispatch->clGetContextInfo (
63+ context,
64+ CL_CONTEXT_NUM_DEVICES,
65+ sizeof (numDevices),
66+ &numDevices,
67+ nullptr );
68+
69+ std::vector<cl_device_id> devices (numDevices);
70+ g_pNextDispatch->clGetContextInfo (
71+ context,
72+ CL_CONTEXT_DEVICES,
73+ numDevices * sizeof (cl_device_id),
74+ devices.data (),
75+ nullptr );
76+
77+ return std::all_of (
78+ devices.begin (),
79+ devices.end (),
80+ [](cl_device_id device) { return isNV (device); });
81+ }
82+
83+ cl_int clBuildProgram_override (
84+ cl_program program,
85+ cl_uint num_devices,
86+ const cl_device_id* device_list,
87+ const char * options,
88+ void (CL_CALLBACK* pfn_notify)(cl_program program, void * user_data),
89+ void* user_data,
90+ cl_int errorCode)
91+ {
92+ if (!isNV (program)) {
93+ return errorCode;
94+ }
95+
96+ // TODO: add the -D define for emulated semaphores and retry build.
97+
98+ return errorCode;
99+ }
100+
101+ cl_program clCreateProgramWithSource_override (
102+ cl_context context,
103+ cl_uint count,
104+ const char ** strings,
105+ const size_t * lengths,
106+ cl_int* errcode_ret)
107+ {
108+ if (!isNV (context)) {
109+ return g_pNextDispatch->clCreateProgramWithSource (
110+ context,
111+ count,
112+ strings,
113+ lengths,
114+ errcode_ret);
115+ }
116+
117+ if (count == 0 || strings == nullptr ) {
118+ if (errcode_ret != nullptr ) {
119+ *errcode_ret = CL_INVALID_VALUE;
120+ }
121+ return nullptr ;
122+ }
123+
124+ std::vector<const char *> newStrings;
125+ newStrings.reserve (count + 1 );
126+ newStrings.insert (newStrings.end (), g_NVSubGroupString);
127+ newStrings.insert (newStrings.end (), strings, strings + count);
128+
129+ std::vector<size_t > newLengths;
130+ if (lengths != nullptr ) {
131+ newLengths.reserve (count + 1 );
132+ newLengths.insert (newLengths.end (), 0 ); // g_NVSubGroupString is nul-terminated
133+ newLengths.insert (newLengths.end (), lengths, lengths + count);
134+ }
135+
136+ return g_pNextDispatch->clCreateProgramWithSource (
137+ context,
138+ count + 1 ,
139+ newStrings.data (),
140+ newLengths.size () ? newLengths.data () : nullptr ,
141+ errcode_ret);
142+ }
143+
144+
145+ cl_int clGetDeviceInfo_override (
146+ cl_device_id device,
147+ cl_device_info param_name,
148+ size_t param_value_size,
149+ void * param_value,
150+ size_t * param_value_size_ret,
151+ cl_int errorCode)
152+ {
153+ cl_uint deviceVendorID = 0 ;
154+ g_pNextDispatch->clGetDeviceInfo (
155+ device,
156+ CL_DEVICE_VENDOR_ID,
157+ sizeof (deviceVendorID),
158+ &deviceVendorID,
159+ nullptr );
160+ if (deviceVendorID != g_NVDeviceVendorID) {
161+ return errorCode;
162+ }
163+
164+ switch (param_name) {
165+ case CL_DEVICE_MAX_NUM_SUB_GROUPS:
166+ {
167+ size_t maxWorkGroupSize = 0 ;
168+ g_pNextDispatch->clGetDeviceInfo (
169+ device,
170+ CL_DEVICE_MAX_WORK_GROUP_SIZE,
171+ sizeof (maxWorkGroupSize),
172+ &maxWorkGroupSize,
173+ nullptr );
174+
175+ size_t warpSize = 0 ;
176+ g_pNextDispatch->clGetDeviceInfo (
177+ device,
178+ CL_DEVICE_WARP_SIZE_NV,
179+ sizeof (warpSize),
180+ &warpSize,
181+ nullptr );
182+
183+ cl_uint maxNumSubGroups =
184+ static_cast <cl_uint>(maxWorkGroupSize / warpSize);
185+ return writeParamToMemory (
186+ param_value_size,
187+ maxNumSubGroups,
188+ param_value_size_ret,
189+ (cl_uint*)param_value);
190+ }
191+ break ;
192+ default :
193+ break ;
194+ }
195+
196+ return errorCode;
197+ }
198+
199+ cl_int clGetKernelSubGroupInfo_override (
200+ cl_kernel kernel,
201+ cl_device_id device,
202+ cl_kernel_sub_group_info param_name,
203+ size_t input_value_size,
204+ const void * input_value,
205+ size_t param_value_size,
206+ void * param_value,
207+ size_t * param_value_size_ret,
208+ cl_int errorCode)
209+ {
210+ cl_uint deviceVendorID = 0 ;
211+ g_pNextDispatch->clGetDeviceInfo (
212+ device,
213+ CL_DEVICE_VENDOR_ID,
214+ sizeof (deviceVendorID),
215+ &deviceVendorID,
216+ nullptr );
217+ if (deviceVendorID != g_NVDeviceVendorID) {
218+ return errorCode;
219+ }
220+
221+ switch (param_name) {
222+ case CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE:
223+ if (input_value == nullptr || input_value_size % sizeof (size_t ) != 0 ) {
224+ return CL_INVALID_VALUE;
225+ }
226+ {
227+ size_t warpSize = 0 ;
228+ g_pNextDispatch->clGetDeviceInfo (
229+ device,
230+ CL_DEVICE_WARP_SIZE_NV,
231+ sizeof (warpSize),
232+ &warpSize,
233+ nullptr );
234+
235+ return writeParamToMemory (
236+ param_value_size,
237+ warpSize,
238+ param_value_size_ret,
239+ (size_t *)param_value);
240+ }
241+ break ;
242+ case CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE:
243+ if (input_value == nullptr || input_value_size % sizeof (size_t ) != 0 ) {
244+ return CL_INVALID_VALUE;
245+ }
246+ {
247+ const size_t dim = input_value_size / sizeof (size_t );
248+ size_t workGroupSize = 1 ;
249+ for (size_t i = 0 ; i < dim; ++i) {
250+ workGroupSize *= ((size_t *)input_value)[i];
251+ }
252+
253+ size_t warpSize = 0 ;
254+ g_pNextDispatch->clGetDeviceInfo (
255+ device,
256+ CL_DEVICE_WARP_SIZE_NV,
257+ sizeof (warpSize),
258+ &warpSize,
259+ nullptr );
260+
261+ size_t numSubGroups = (workGroupSize + warpSize - 1 ) / warpSize;
262+ return writeParamToMemory (
263+ param_value_size,
264+ numSubGroups,
265+ param_value_size_ret,
266+ (size_t *)param_value);
267+ }
268+ break ;
269+ case CL_KERNEL_LOCAL_SIZE_FOR_SUB_GROUP_COUNT:
270+ if (input_value == nullptr || input_value_size != sizeof (size_t )) {
271+ return CL_INVALID_VALUE;
272+ }
273+ {
274+ }
275+ break ;
276+ case CL_KERNEL_MAX_NUM_SUB_GROUPS:
277+ {
278+ size_t maxWorkGroupSize = 0 ;
279+ g_pNextDispatch->clGetKernelWorkGroupInfo (
280+ kernel,
281+ device,
282+ CL_KERNEL_WORK_GROUP_SIZE,
283+ sizeof (maxWorkGroupSize),
284+ &maxWorkGroupSize,
285+ nullptr );
286+
287+ size_t warpSize = 0 ;
288+ g_pNextDispatch->clGetDeviceInfo (
289+ device,
290+ CL_DEVICE_WARP_SIZE_NV,
291+ sizeof (warpSize),
292+ &warpSize,
293+ nullptr );
294+
295+ size_t maxNumSubGroups = maxWorkGroupSize / warpSize;
296+ return writeParamToMemory (
297+ param_value_size,
298+ maxNumSubGroups,
299+ param_value_size_ret,
300+ (size_t *)param_value);
301+ }
302+ break ;
303+ case CL_KERNEL_COMPILE_NUM_SUB_GROUPS:
304+ // Not sure how to implement this one...
305+ break ;
306+ default :
307+ break ;
308+ }
309+
310+ return errorCode;
311+ }
0 commit comments