diff --git a/src/cellflow/preprocessing/_wknn.py b/src/cellflow/preprocessing/_wknn.py index 222a9dcf..955c9fb5 100644 --- a/src/cellflow/preprocessing/_wknn.py +++ b/src/cellflow/preprocessing/_wknn.py @@ -185,8 +185,13 @@ def _build_nn( try: from cuml.neighbors import NearestNeighbors - jax.devices("gpu") - except ImportError: + has_gpu = any(d.platform == "gpu" for d in jax.devices()) + if has_gpu: + jax.devices("gpu") + else: + raise RuntimeError("No GPU backend available") + + except (ImportError, RuntimeError): logger.info( "cuML is not installed or GPU is not available. Falling back to neighborhood estimation using CPU with pynndescent." )