11/*
2- // Copyright (c) 2023 Ben Ashbaugh
2+ // Copyright (c) 2023-2025 Ben Ashbaugh
33//
44// SPDX-License-Identifier: MIT
55*/
@@ -183,6 +183,20 @@ static cl_svm_capabilities_khr getSystemSVMCaps(cl_device_id device)
183183 return ret;
184184}
185185
186+ struct SUSMFuncs
187+ {
188+ clHostMemAllocINTEL_fn clHostMemAllocINTEL;
189+ clDeviceMemAllocINTEL_fn clDeviceMemAllocINTEL;
190+ clSharedMemAllocINTEL_fn clSharedMemAllocINTEL;
191+ clMemFreeINTEL_fn clMemFreeINTEL;
192+ clMemBlockingFreeINTEL_fn clMemBlockingFreeINTEL;
193+ clGetMemAllocInfoINTEL_fn clGetMemAllocInfoINTEL;
194+ clSetKernelArgMemPointerINTEL_fn clSetKernelArgMemPointerINTEL;
195+ clEnqueueMemFillINTEL_fn clEnqueueMemFillINTEL;
196+ clEnqueueMemcpyINTEL_fn clEnqueueMemcpyINTEL;
197+ clEnqueueMemAdviseINTEL_fn clEnqueueMemAdviseINTEL;
198+ };
199+
186200struct SAllocInfo
187201{
188202 cl_uint TypeIndex = ~0 ;
@@ -219,6 +233,7 @@ struct SLayerContext
219233
220234 for (auto platform: platforms) {
221235 getSVMTypesForPlatform (platform);
236+ getUSMFuncsForPlatform (platform);
222237 }
223238 }
224239
@@ -232,6 +247,11 @@ struct SLayerContext
232247 return TypeCapsDevice[device];
233248 }
234249
250+ const SUSMFuncs& getUSMFuncs (cl_platform_id platform)
251+ {
252+ return USMFuncs[platform];
253+ }
254+
235255 bool isKnownAlloc (cl_context context, const void * ptr) const
236256 {
237257 if (AllocMaps.find (context) != AllocMaps.end ()) {
@@ -435,9 +455,47 @@ struct SLayerContext
435455 }
436456 }
437457
458+ void getUSMFuncsForPlatform (cl_platform_id platform)
459+ {
460+ SUSMFuncs& funcs = USMFuncs[platform];
461+
462+ funcs.clHostMemAllocINTEL = (clHostMemAllocINTEL_fn)
463+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
464+ platform, " clHostMemAllocINTEL" );
465+ funcs.clDeviceMemAllocINTEL = (clDeviceMemAllocINTEL_fn)
466+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
467+ platform, " clDeviceMemAllocINTEL" );
468+ funcs.clSharedMemAllocINTEL = (clSharedMemAllocINTEL_fn)
469+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
470+ platform, " clSharedMemAllocINTEL" );
471+ funcs.clMemFreeINTEL = (clMemFreeINTEL_fn)
472+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
473+ platform, " clMemFreeINTEL" );
474+ funcs.clMemBlockingFreeINTEL = (clMemBlockingFreeINTEL_fn)
475+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
476+ platform, " clMemBlockingFreeINTEL" );
477+ funcs.clGetMemAllocInfoINTEL = (clGetMemAllocInfoINTEL_fn)
478+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
479+ platform, " clGetMemAllocInfoINTEL" );
480+ funcs.clSetKernelArgMemPointerINTEL = (clSetKernelArgMemPointerINTEL_fn)
481+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
482+ platform, " clSetKernelArgMemPointerINTEL" );
483+ funcs.clEnqueueMemFillINTEL = (clEnqueueMemFillINTEL_fn)
484+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
485+ platform, " clEnqueueMemFillINTEL" );
486+ funcs.clEnqueueMemcpyINTEL = (clEnqueueMemcpyINTEL_fn)
487+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
488+ platform, " clEnqueueMemcpyINTEL" );
489+ funcs.clEnqueueMemAdviseINTEL = (clEnqueueMemAdviseINTEL_fn)
490+ g_pNextDispatch->clGetExtensionFunctionAddressForPlatform (
491+ platform, " clEnqueueMemAdviseINTEL" );
492+ }
493+
438494 std::map<cl_platform_id, std::vector<cl_svm_capabilities_khr>> TypeCapsPlatform;
439495 std::map<cl_device_id, std::vector<cl_svm_capabilities_khr>> TypeCapsDevice;
440496
497+ std::map<cl_platform_id, SUSMFuncs> USMFuncs;
498+
441499 typedef std::map<const void *, SAllocInfo> CAllocMap;
442500 std::map<cl_context, CAllocMap> AllocMaps;
443501
@@ -597,6 +655,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
597655 cl_int* errcode_ret)
598656{
599657 cl_platform_id platform = getPlatform (context);
658+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
600659
601660 const auto & typeCapsPlatform = getLayerContext ().getSVMCaps (platform);
602661 if (svm_type_index >= typeCapsPlatform.size ()) {
@@ -626,7 +685,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
626685 const auto caps = typeCapsPlatform[svm_type_index];
627686 if ((caps & CL_SVM_TYPE_MACRO_DEVICE_KHR) == CL_SVM_TYPE_MACRO_DEVICE_KHR) {
628687 isUSMPointer = true ;
629- ret = clDeviceMemAllocINTEL (
688+ ret = USMFuncs. clDeviceMemAllocINTEL (
630689 context,
631690 device,
632691 nullptr ,
@@ -636,7 +695,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
636695 }
637696 else if ((caps & CL_SVM_TYPE_MACRO_HOST_KHR) == CL_SVM_TYPE_MACRO_HOST_KHR) {
638697 isUSMPointer = true ;
639- ret = clHostMemAllocINTEL (
698+ ret = USMFuncs. clHostMemAllocINTEL (
640699 context,
641700 nullptr ,
642701 size,
@@ -653,7 +712,7 @@ void* CL_API_CALL clSVMAllocWithPropertiesKHR_EMU(
653712 }
654713 else if ((caps & CL_SVM_TYPE_MACRO_SINGLE_DEVICE_SHARED_KHR) == CL_SVM_TYPE_MACRO_SINGLE_DEVICE_SHARED_KHR) {
655714 isUSMPointer = true ;
656- ret = clSharedMemAllocINTEL (
715+ ret = USMFuncs. clSharedMemAllocINTEL (
657716 context,
658717 device,
659718 nullptr ,
@@ -739,7 +798,9 @@ cl_int CL_API_CALL clSVMFreeWithPropertiesKHR_EMU(
739798
740799 cl_int errorCode = CL_SUCCESS;
741800 if (isUSMPtr (context, ptr)) {
742- errorCode = clMemBlockingFreeINTEL (
801+ cl_platform_id platform = getPlatform (context);
802+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
803+ errorCode = USMFuncs.clMemBlockingFreeINTEL (
743804 context,
744805 ptr);
745806 } else if (isSVMPtr (context, ptr)) {
@@ -1191,7 +1252,9 @@ cl_int CL_API_CALL clSetKernelArgSVMPointer_override(
11911252 cl_context context = getContext (kernel);
11921253
11931254 if (isUSMPtr (context, arg_value)) {
1194- return clSetKernelArgMemPointerINTEL (
1255+ cl_platform_id platform = getPlatform (context);
1256+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1257+ return USMFuncs.clSetKernelArgMemPointerINTEL (
11951258 kernel,
11961259 arg_index,
11971260 arg_value);
@@ -1284,7 +1347,9 @@ void CL_API_CALL clSVMFree_override(
12841347 void * ptr)
12851348{
12861349 if (isUSMPtr (context, ptr)) {
1287- clMemFreeINTEL (context, ptr);
1350+ cl_platform_id platform = getPlatform (context);
1351+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1352+ USMFuncs.clMemFreeINTEL (context, ptr);
12881353 } else {
12891354 g_pNextDispatch->clSVMFree (context, ptr);
12901355 }
@@ -1351,7 +1416,9 @@ cl_int CL_API_CALL clEnqueueSVMMemcpy_override(
13511416 }
13521417
13531418 if (isUSMPtr (context, dst_ptr) || isUSMPtr (context, src_ptr)) {
1354- cl_int ret = clEnqueueMemcpyINTEL (
1419+ cl_platform_id platform = getPlatform (context);
1420+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1421+ cl_int ret = USMFuncs.clEnqueueMemcpyINTEL (
13551422 command_queue,
13561423 blocking_copy,
13571424 dst_ptr,
@@ -1407,7 +1474,9 @@ cl_int CL_API_CALL clEnqueueSVMMemFill_override(
14071474 }
14081475
14091476 if (isUSMPtr (context, svm_ptr)) {
1410- cl_int ret = clEnqueueMemFillINTEL (
1477+ cl_platform_id platform = getPlatform (context);
1478+ const auto & USMFuncs = getLayerContext ().getUSMFuncs (platform);
1479+ cl_int ret = USMFuncs.clEnqueueMemFillINTEL (
14111480 command_queue,
14121481 svm_ptr,
14131482 pattern,
0 commit comments