Skip to content

Commit 9aaa671

Browse files
authored
Fix #1051: Make _graph.py compatible with cuda-python==12.6.* and fix tests (#1236)
1 parent f9df16f commit 9aaa671

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

cuda_core/cuda/core/experimental/_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,10 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
318318
raise RuntimeError(
319319
"Instantiation for device launch failed due to the nodes belonging to different contexts."
320320
)
321-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED:
321+
elif (
322+
_py_major_minor >= (12, 8)
323+
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
324+
):
322325
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
323326
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
324327
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")

cuda_core/tests/test_graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):
304304
try:
305305
gb_if, gb_else = gb.if_else(handle)
306306
except RuntimeError as e:
307-
with pytest.raises(RuntimeError, match="^Driver version"):
307+
with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
308308
raise e
309309
gb.end_building()
310310
b.close()
@@ -377,7 +377,7 @@ def test_graph_conditional_switch(init_cuda, condition_value):
377377
try:
378378
gb_case = list(gb.switch(handle, 3))
379379
except RuntimeError as e:
380-
with pytest.raises(RuntimeError, match="^Driver version"):
380+
with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
381381
raise e
382382
gb.end_building()
383383
b.close()
@@ -568,7 +568,7 @@ def build_graph(condition_value):
568568
try:
569569
gb_case = list(gb.switch(handle, 3))
570570
except Exception as e:
571-
with pytest.raises(RuntimeError, match="^Driver version"):
571+
with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
572572
raise e
573573
gb.end_building()
574574
raise e
@@ -599,7 +599,7 @@ def build_graph(condition_value):
599599
try:
600600
graph_variants = [build_graph(0), build_graph(1), build_graph(2)]
601601
except Exception as e:
602-
with pytest.raises(RuntimeError, match="^Driver version"):
602+
with pytest.raises(RuntimeError, match="^(Driver|Binding) version"):
603603
raise e
604604
b.close()
605605
pytest.skip("Driver does not support conditional switch")

0 commit comments

Comments
 (0)