Skip to content

Commit d7fe515

Browse files
committed
add fixes for a few more unified SVM corner cases
1 parent c876e52 commit d7fe515

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

layers/99_svmplusplus/emulate.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,24 @@ cl_int CL_API_CALL clSetKernelExecInfo_override(
11741174
ret = (check == CL_SUCCESS) ? CL_SUCCESS : ret;
11751175
return check;
11761176
}
1177+
case CL_KERNEL_EXEC_INFO_SVM_PTRS:
1178+
{
1179+
const void* const* svmPtrs = (const void* const*)param_value;
1180+
const size_t numPtrs = param_value_size / sizeof(void*);
1181+
1182+
std::vector<const void*> nonNullPtrs;
1183+
for (size_t i = 0; i < numPtrs; ++i) {
1184+
if (svmPtrs[i] != nullptr) {
1185+
nonNullPtrs.push_back(svmPtrs[i]);
1186+
}
1187+
}
1188+
1189+
return g_pNextDispatch->clSetKernelExecInfo(
1190+
kernel,
1191+
CL_KERNEL_EXEC_INFO_SVM_PTRS,
1192+
nonNullPtrs.size() * sizeof(void*),
1193+
nonNullPtrs.empty() ? nullptr : nonNullPtrs.data());
1194+
}
11771195
default: break;
11781196
}
11791197

@@ -1195,6 +1213,38 @@ void CL_API_CALL clSVMFree_override(
11951213
}
11961214
}
11971215

1216+
cl_int CL_API_CALL clEnqueueSVMFree_override(
1217+
cl_command_queue command_queue,
1218+
cl_uint num_svm_pointers,
1219+
void* svm_pointers[],
1220+
void (CL_CALLBACK* pfn_free_func)(
1221+
cl_command_queue queue,
1222+
cl_uint num_svm_pointers,
1223+
void* svm_pointers[],
1224+
void* user_data),
1225+
void* user_data,
1226+
cl_uint num_events_in_wait_list,
1227+
const cl_event* event_wait_list,
1228+
cl_event* event)
1229+
{
1230+
std::vector<void*> nonNullPtrs;
1231+
for (cl_uint i = 0; i < num_svm_pointers; ++i) {
1232+
if (svm_pointers[i] != nullptr) {
1233+
nonNullPtrs.push_back(svm_pointers[i]);
1234+
}
1235+
}
1236+
1237+
return g_pNextDispatch->clEnqueueSVMFree(
1238+
command_queue,
1239+
static_cast<cl_uint>(nonNullPtrs.size()),
1240+
nonNullPtrs.empty() ? nullptr : nonNullPtrs.data(),
1241+
pfn_free_func,
1242+
user_data,
1243+
num_events_in_wait_list,
1244+
event_wait_list,
1245+
event);
1246+
}
1247+
11981248
cl_int CL_API_CALL clEnqueueSVMMemcpy_override(
11991249
cl_command_queue command_queue,
12001250
cl_bool blocking_copy,

layers/99_svmplusplus/emulate.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@ void CL_API_CALL clSVMFree_override(
7575
cl_context context,
7676
void* ptr);
7777

78+
cl_int CL_API_CALL clEnqueueSVMFree_override(
79+
cl_command_queue command_queue,
80+
cl_uint num_svm_pointers,
81+
void* svm_pointers[],
82+
void (CL_CALLBACK* pfn_free_func)(
83+
cl_command_queue queue,
84+
cl_uint num_svm_pointers,
85+
void* svm_pointers[],
86+
void* user_data),
87+
void* user_data,
88+
cl_uint num_events_in_wait_list,
89+
const cl_event* event_wait_list,
90+
cl_event* event);
91+
7892
cl_int CL_API_CALL clEnqueueSVMMemcpy_override(
7993
cl_command_queue command_queue,
8094
cl_bool blocking_copy,

layers/99_svmplusplus/main.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ static void _init_dispatch()
5959
dispatch.clSetKernelArgSVMPointer = clSetKernelArgSVMPointer_override;
6060
dispatch.clSetKernelExecInfo = clSetKernelExecInfo_override;
6161
dispatch.clSVMFree = clSVMFree_override;
62+
dispatch.clEnqueueSVMFree = clEnqueueSVMFree_override;
6263
dispatch.clEnqueueSVMMemcpy = clEnqueueSVMMemcpy_override;
6364
dispatch.clEnqueueSVMMemFill = clEnqueueSVMMemFill_override;
6465
dispatch.clEnqueueSVMMigrateMem = clEnqueueSVMMigrateMem_override;

0 commit comments

Comments
 (0)