Skip to content

Commit f52c71a

Browse files
authored
[NUMBA][CUDA] Add numba_debug flag. (#1216)
* [NUMBA][CUDA] Add numba_debug flag. Adds a flag numba_debug. When used with debug flag enables debug support in numba cuda. * removed description * Added filter to test for cuda 13.2 * removed white space * removed redundant ProgramOption test * removed unused var * fixed comment
1 parent b9c76b3 commit f52c71a

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

cuda_core/cuda/core/experimental/_program.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class ProgramOptions:
298298
split_compile: int | None = None
299299
fdevice_syntax_only: bool | None = None
300300
minimal: bool | None = None
301+
numba_debug: bool | None = None # Custom option for Numba debugging
301302

302303
def __post_init__(self):
303304
self._name = self.name.encode()
@@ -418,6 +419,8 @@ def __post_init__(self):
418419
self._formatted_options.append("--fdevice-syntax-only")
419420
if self.minimal is not None and self.minimal:
420421
self._formatted_options.append("--minimal")
422+
if self.numba_debug:
423+
self._formatted_options.append("--numba-debug")
421424

422425
def _as_bytes(self):
423426
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

cuda_core/tests/test_program.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,29 @@ def _is_nvvm_available():
3131
)
3232

3333
try:
34-
from cuda.core.experimental._utils.cuda_utils import driver, handle_return
34+
from cuda.core.experimental._utils.cuda_utils import driver, handle_return, nvrtc
3535

3636
_cuda_driver_version = handle_return(driver.cuDriverGetVersion())
3737
except Exception:
3838
_cuda_driver_version = 0
3939

40+
41+
def _get_nvrtc_version_for_tests():
42+
"""
43+
Get NVRTC version.
44+
45+
Returns:
46+
int: Version in format major * 1000 + minor * 100 (e.g., 13200 for CUDA 13.2)
47+
None: If NVRTC is not available
48+
"""
49+
try:
50+
nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion())
51+
version = nvrtc_major * 1000 + nvrtc_minor * 100
52+
return version
53+
except Exception:
54+
return None
55+
56+
4057
_libnvvm_version = None
4158
_libnvvm_version_attempted = False
4259

@@ -176,6 +193,13 @@ def ptx_code_object():
176193
[
177194
ProgramOptions(name="abc"),
178195
ProgramOptions(device_code_optimize=True, debug=True),
196+
pytest.param(
197+
ProgramOptions(debug=True, numba_debug=True),
198+
marks=pytest.mark.skipif(
199+
(_get_nvrtc_version_for_tests() or 0) < 13200,
200+
reason="numba_debug requires NVRTC >= 13.2",
201+
),
202+
),
179203
ProgramOptions(relocatable_device_code=True, max_register_count=32),
180204
ProgramOptions(ftz=True, prec_sqrt=False, prec_div=False),
181205
ProgramOptions(fma=False, use_fast_math=True),

0 commit comments

Comments
 (0)