Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ extract_cuda_stream(pybind11::object python_stream)
throw pybind11::type_error(error.str());
}

//Currently there is only version 0.
const auto protocol_version = cuda_stream_protocol[0].cast<std::size_t>();
if (protocol_version == 0)
if (protocol_version != 0)
{
std::stringstream error;
error << "Expected `__cuda_stream__` protocol version 0, but got "
Expand Down
18 changes: 16 additions & 2 deletions test/bindings/python/test_unstructured_domain_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,20 @@

try:
import cupy as cp

# Mock to implement CUDA's Stream protocol: https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
# Test helper: wraps a CuPy stream and exposes it only through `__cuda_stream__`.
class CUDAStreamProtocolMock:
# Forward all positional and keyword arguments unchanged to `cp.cuda.Stream`
# and keep the created stream on the instance.
def __init__(self, *args, **kwargs):
self.cupy_stream = cp.cuda.Stream(*args, **kwargs)

# Return the pair `(0, ptr)`: the constant 0 (presumably the protocol version;
# confirm against the linked protocol description) followed by the `ptr`
# attribute of the wrapped CuPy stream.
def __cuda_stream__(self):
return 0, self.cupy_stream.ptr

STREAM_TYPES_TO_TEST = [None, cp.cuda.Stream, CUDAStreamProtocolMock]

except ImportError:
cp = None
STREAM_TYPES_TO_TEST = [None] # Must be at least one element.

import ghex
from ghex.context import make_context
Expand Down Expand Up @@ -217,6 +229,7 @@
@pytest.mark.parametrize("on_gpu", [True, False])
@pytest.mark.mpi
def test_domain_descriptor(on_gpu, capsys, mpi_cart_comm, dtype):
# Does not use streams.

if on_gpu and cp is None:
pytest.skip(reason="`CuPy` is not installed.")
Expand Down Expand Up @@ -289,8 +302,9 @@ def check_field(data, order):

@pytest.mark.parametrize("dtype", [np.float64, np.float32, np.int32, np.int64])
@pytest.mark.parametrize("on_gpu", [True, False])
@pytest.mark.parametrize("stream_type", STREAM_TYPES_TO_TEST)
@pytest.mark.mpi
def test_domain_descriptor_async(on_gpu, capsys, mpi_cart_comm, dtype):
def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, dtype):

if on_gpu:
if cp is None:
Expand Down Expand Up @@ -354,7 +368,7 @@ def check_field(data, order, stream):
d1, f1 = make_field("C")
d2, f2 = make_field("F")

stream = cp.cuda.Stream(non_blocking=True) if on_gpu else None
stream = None if stream_type is None else stream_type(non_blocking=True)
handle = co.schedule_exchange(stream, [pattern(f1), pattern(f2)])
assert not co.has_scheduled_exchange()

Expand Down
Loading