diff --git a/bindings/python/src/_pyghex/unstructured/communication_object.cpp b/bindings/python/src/_pyghex/unstructured/communication_object.cpp index fe624d0e..66de9143 100644 --- a/bindings/python/src/_pyghex/unstructured/communication_object.cpp +++ b/bindings/python/src/_pyghex/unstructured/communication_object.cpp @@ -55,8 +55,9 @@ extract_cuda_stream(pybind11::object python_stream) throw pybind11::type_error(error.str()); } + //Currently there is only version 0. const auto protocol_version = cuda_stream_protocol[0].cast(); - if (protocol_version == 0) + if (protocol_version != 0) { std::stringstream error; error << "Expected `__cuda_stream__` protocol version 0, but got " diff --git a/test/bindings/python/test_unstructured_domain_descriptor.py b/test/bindings/python/test_unstructured_domain_descriptor.py index c39d2de3..844f8851 100644 --- a/test/bindings/python/test_unstructured_domain_descriptor.py +++ b/test/bindings/python/test_unstructured_domain_descriptor.py @@ -12,8 +12,20 @@ try: import cupy as cp + + # Mock to implement CUDA's Stream protocol: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol + class CUDAStreamProtocolMock: + def __init__(self, *args, **kwargs): + self.cupy_stream = cp.cuda.Stream(*args, **kwargs) + + def __cuda_stream__(self): + return 0, self.cupy_stream.ptr + + STREAM_TYPES_TO_TEST = [None, cp.cuda.Stream, CUDAStreamProtocolMock] + except ImportError: cp = None + STREAM_TYPES_TO_TEST = [None] # Must be at least one element. import ghex from ghex.context import make_context @@ -217,6 +229,7 @@ @pytest.mark.parametrize("on_gpu", [True, False]) @pytest.mark.mpi def test_domain_descriptor(on_gpu, capsys, mpi_cart_comm, dtype): + # Does not use streams. 
if on_gpu and cp is None: pytest.skip(reason="`CuPy` is not installed.") @@ -289,8 +302,9 @@ def check_field(data, order): @pytest.mark.parametrize("dtype", [np.float64, np.float32, np.int32, np.int64]) @pytest.mark.parametrize("on_gpu", [True, False]) +@pytest.mark.parametrize("stream_type", STREAM_TYPES_TO_TEST) @pytest.mark.mpi -def test_domain_descriptor_async(on_gpu, capsys, mpi_cart_comm, dtype): +def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, dtype): if on_gpu: if cp is None: @@ -354,7 +368,7 @@ def check_field(data, order, stream): d1, f1 = make_field("C") d2, f2 = make_field("F") - stream = cp.cuda.Stream(non_blocking=True) if on_gpu else None + stream = None if stream_type is None else stream_type(non_blocking=True) handle = co.schedule_exchange(stream, [pattern(f1), pattern(f2)]) assert not co.has_scheduled_exchange()