@@ -2310,14 +2310,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
2310
2310
}
2311
2311
}
2312
2312
2313
- // Gets UR argument struct for a given kernel and device based on the argument
2314
- // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2315
- // the graphs extension (LaunchWithArgs for graphs is planned future work).
2316
- static void GetUrArgsBasedOnType (
2313
+ // Sets arguments for a given kernel and device based on the argument type.
2314
+ // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2315
+ // extension.
2316
+ static void SetArgBasedOnType (
2317
+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2317
2318
device_image_impl *DeviceImageImpl,
2318
2319
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2319
- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2320
- std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2320
+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2321
2321
switch (Arg.MType ) {
2322
2322
case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2323
2323
break ;
@@ -2337,61 +2337,52 @@ static void GetUrArgsBasedOnType(
2337
2337
getMemAllocationFunc
2338
2338
? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2339
2339
: nullptr ;
2340
- ur_exp_kernel_arg_value_t Value = {};
2341
- Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2342
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2343
- UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2344
- static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2345
- Value});
2340
+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2341
+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2342
+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2343
+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2344
+ &MemObjData, MemArg);
2346
2345
break ;
2347
2346
}
2348
2347
case kernel_param_kind_t ::kind_std_layout: {
2349
- ur_exp_kernel_arg_type_t Type;
2350
2348
if (Arg.MPtr ) {
2351
- Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
2349
+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2350
+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2352
2351
} else {
2353
- Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
2352
+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2353
+ Arg.MSize , nullptr );
2354
2354
}
2355
- ur_exp_kernel_arg_value_t Value = {};
2356
- Value.value = {Arg.MPtr };
2357
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2358
- Type, static_cast <uint32_t >(NextTrueIndex),
2359
- static_cast <size_t >(Arg.MSize ), Value});
2360
2355
2361
2356
break ;
2362
2357
}
2363
2358
case kernel_param_kind_t ::kind_sampler: {
2364
2359
sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2365
- ur_exp_kernel_arg_value_t Value = {};
2366
- Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2367
- ->getOrCreateSampler (ContextImpl);
2368
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2369
- UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2370
- static_cast <uint32_t >(NextTrueIndex),
2371
- sizeof (ur_sampler_handle_t ), Value});
2360
+ ur_sampler_handle_t Sampler =
2361
+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2362
+ ->getOrCreateSampler (ContextImpl);
2363
+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2364
+ nullptr , Sampler);
2372
2365
break ;
2373
2366
}
2374
2367
case kernel_param_kind_t ::kind_pointer: {
2375
- ur_exp_kernel_arg_value_t Value = {};
2376
- // We need to de-rerence to get the actual USM allocation - that's the
2368
+ // We need to dereference this to get the actual USM allocation - that's the
2377
2369
// pointer UR is expecting.
2378
- Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2379
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2380
- UR_EXP_KERNEL_ARG_TYPE_POINTER,
2381
- static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2382
- Value});
2370
+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2371
+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2372
+ nullptr , Ptr);
2383
2373
break ;
2384
2374
}
2385
2375
case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2386
2376
assert (DeviceImageImpl != nullptr );
2387
2377
ur_mem_handle_t SpecConstsBuffer =
2388
2378
DeviceImageImpl->get_spec_const_buffer_ref ();
2389
- ur_exp_kernel_arg_value_t Value = {};
2390
- Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2391
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2392
- UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2393
- static_cast <uint32_t >(NextTrueIndex),
2394
- sizeof (SpecConstsBuffer), Value});
2379
+
2380
+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2381
+ MemObjProps.pNext = nullptr ;
2382
+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2383
+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2384
+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2385
+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2395
2386
break ;
2396
2387
}
2397
2388
case kernel_param_kind_t ::kind_invalid:
@@ -2424,32 +2415,22 @@ static ur_result_t SetKernelParamsAndLaunch(
2424
2415
DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
2425
2416
}
2426
2417
2427
- std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2428
- UrArgs.reserve (Args.size ());
2429
-
2430
2418
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2431
- auto setFunc = [&UrArgs ,
2419
+ auto setFunc = [&Adapter, Kernel ,
2432
2420
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2433
2421
size_t NextTrueIndex) {
2434
2422
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2435
2423
switch (ParamDesc.kind ) {
2436
2424
case kernel_param_kind_t ::kind_std_layout: {
2437
2425
int Size = ParamDesc.info ;
2438
- ur_exp_kernel_arg_value_t Value = {};
2439
- Value.value = ArgPtr;
2440
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2441
- UR_EXP_KERNEL_ARG_TYPE_VALUE,
2442
- static_cast <uint32_t >(NextTrueIndex),
2443
- static_cast <size_t >(Size), Value});
2426
+ Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2427
+ Size, nullptr , ArgPtr);
2444
2428
break ;
2445
2429
}
2446
2430
case kernel_param_kind_t ::kind_pointer: {
2447
- ur_exp_kernel_arg_value_t Value = {};
2448
- Value.pointer = *static_cast <const void *const *>(ArgPtr);
2449
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2450
- UR_EXP_KERNEL_ARG_TYPE_POINTER,
2451
- static_cast <uint32_t >(NextTrueIndex),
2452
- sizeof (Value.pointer ), Value});
2431
+ const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2432
+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2433
+ nullptr , Ptr);
2453
2434
break ;
2454
2435
}
2455
2436
default :
@@ -2459,10 +2440,10 @@ static ur_result_t SetKernelParamsAndLaunch(
2459
2440
applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
2460
2441
KernelParamDescGetter, setFunc);
2461
2442
} else {
2462
- auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc , &Queue ,
2463
- &UrArgs ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2464
- GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2465
- Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
2443
+ auto setFunc = [&Adapter, Kernel, &DeviceImageImpl , &getMemAllocationFunc ,
2444
+ &Queue ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2445
+ SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2446
+ Queue.getContextImpl (), Arg, NextTrueIndex);
2466
2447
};
2467
2448
applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2468
2449
}
@@ -2475,12 +2456,8 @@ static ur_result_t SetKernelParamsAndLaunch(
2475
2456
// CUDA-style local memory setting. Note that we may have -1 as a position,
2476
2457
// this indicates the buffer is actually unused and was elided.
2477
2458
if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2478
- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2479
- nullptr ,
2480
- UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2481
- static_cast <uint32_t >(ImplicitLocalArg.value ()),
2482
- WorkGroupMemorySize,
2483
- {nullptr }});
2459
+ Adapter.call <UrApiKind::urKernelSetArgLocal>(
2460
+ Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
2484
2461
}
2485
2462
2486
2463
adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2538,104 +2515,20 @@ static ur_result_t SetKernelParamsAndLaunch(
2538
2515
{{WorkGroupMemorySize}}});
2539
2516
}
2540
2517
ur_event_handle_t UREvent = nullptr ;
2541
- ur_result_t Error =
2542
- Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2543
- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2544
- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2545
- &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2546
- property_list.size (),
2547
- property_list.empty () ? nullptr : property_list.data (),
2548
- RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2549
- OutEventImpl ? &UREvent : nullptr );
2518
+ ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2519
+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2520
+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2521
+ LocalSize, property_list.size (),
2522
+ property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2523
+ RawEvents.empty () ? nullptr : &RawEvents[0 ],
2524
+ OutEventImpl ? &UREvent : nullptr );
2550
2525
if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
2551
2526
OutEventImpl->setHandle (UREvent);
2552
2527
}
2553
2528
2554
2529
return Error;
2555
2530
}
2556
2531
2557
- // Sets arguments for a given kernel and device based on the argument type.
2558
- // This is a legacy path which the graphs extension still uses.
2559
- static void SetArgBasedOnType (
2560
- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2561
- device_image_impl *DeviceImageImpl,
2562
- const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2563
- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2564
- switch (Arg.MType ) {
2565
- case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2566
- break ;
2567
- case kernel_param_kind_t ::kind_work_group_memory:
2568
- break ;
2569
- case kernel_param_kind_t ::kind_stream:
2570
- break ;
2571
- case kernel_param_kind_t ::kind_dynamic_accessor:
2572
- case kernel_param_kind_t ::kind_accessor: {
2573
- Requirement *Req = (Requirement *)(Arg.MPtr );
2574
-
2575
- // getMemAllocationFunc is nullptr when there are no requirements. However,
2576
- // we may pass default constructed accessors to a command, which don't add
2577
- // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2578
- // valid case, so we need to properly handle it.
2579
- ur_mem_handle_t MemArg =
2580
- getMemAllocationFunc
2581
- ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2582
- : nullptr ;
2583
- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2584
- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2585
- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2586
- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2587
- &MemObjData, MemArg);
2588
- break ;
2589
- }
2590
- case kernel_param_kind_t ::kind_std_layout: {
2591
- if (Arg.MPtr ) {
2592
- Adapter.call <UrApiKind::urKernelSetArgValue>(
2593
- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2594
- } else {
2595
- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2596
- Arg.MSize , nullptr );
2597
- }
2598
-
2599
- break ;
2600
- }
2601
- case kernel_param_kind_t ::kind_sampler: {
2602
- sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2603
- ur_sampler_handle_t Sampler =
2604
- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2605
- ->getOrCreateSampler (ContextImpl);
2606
- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2607
- nullptr , Sampler);
2608
- break ;
2609
- }
2610
- case kernel_param_kind_t ::kind_pointer: {
2611
- // We need to de-rerence this to get the actual USM allocation - that's the
2612
- // pointer UR is expecting.
2613
- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2614
- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2615
- nullptr , Ptr);
2616
- break ;
2617
- }
2618
- case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2619
- assert (DeviceImageImpl != nullptr );
2620
- ur_mem_handle_t SpecConstsBuffer =
2621
- DeviceImageImpl->get_spec_const_buffer_ref ();
2622
-
2623
- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2624
- MemObjProps.pNext = nullptr ;
2625
- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2626
- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2627
- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2628
- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2629
- break ;
2630
- }
2631
- case kernel_param_kind_t ::kind_invalid:
2632
- throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2633
- " Invalid kernel param kind " +
2634
- codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2635
- break ;
2636
- }
2637
- }
2638
-
2639
2532
static std::tuple<ur_kernel_handle_t , device_image_impl *,
2640
2533
const KernelArgMask *>
2641
2534
getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments