diff --git a/docs/src/installation.rst b/docs/src/installation.rst index 66ea42ef..8b9e29ff 100644 --- a/docs/src/installation.rst +++ b/docs/src/installation.rst @@ -11,16 +11,14 @@ The Python package can be installed with pip by simply running pip install sphericart -This basic package makes use of NumPy. A PyTorch-based implementation can be installed with +This basic package makes use of NumPy. Implementations supporting PyTorch and JAX can be installed with .. code-block:: bash pip install sphericart[torch] + pip install sphericart[jax] -This pre-built version available on PyPI sacrifices some performance to ensure it -can run on all systems, and it does not include GPU support. -If you need an extra 5-10% of performance, you want to evaluate the spherical harmonics on GPUs, -and/or you want to use it in JAX, you should build the code from source: +If you need an extra 5-10% of performance, you should build the code from source: .. code-block:: bash diff --git a/scripts/create-jax-versions-range.py b/scripts/create-jax-versions-range.py index 1eed6e43..d615504c 100644 --- a/scripts/create-jax-versions-range.py +++ b/scripts/create-jax-versions-range.py @@ -26,7 +26,7 @@ if match is None: raise ValueError(f"unexpected Requires-Dist format: {version}") - major, minor, patch = match.groups() + major, minor, patch, *_ = match.groups() major = int(major) minor = int(minor) patch = int(patch) if patch is not None else 0 diff --git a/sphericart-jax/setup.py b/sphericart-jax/setup.py index ace77bee..64f75848 100644 --- a/sphericart-jax/setup.py +++ b/sphericart-jax/setup.py @@ -4,8 +4,8 @@ from setuptools import Extension, setup from setuptools.command.bdist_egg import bdist_egg +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext -from wheel.bdist_wheel import bdist_wheel ROOT = os.path.realpath(os.path.dirname(__file__)) @@ -31,7 +31,7 @@ class cmake_ext(build_ext): def run(self): import jax - jax_major, jax_minor, jax_patch = jax.__version__.split(".") + jax_major, jax_minor, jax_patch, *_ = jax.__version__.split(".") source_dir = ROOT build_dir = os.path.join(ROOT, "build", "cmake-build") @@ -99,8 +99,8 @@ def run(self): import jax # if we have jax, we are building a wheel - requires specific jax version - jax_v_major, jax_v_minor, jax_v_patch = jax.__version__.split(".") - jax_version = f"== {jax_v_major}.{jax_v_minor}.{jax_v_patch}" + jax_v_major, jax_v_minor, jax_v_patch, *_ = jax.__version__.split(".") + jax_version = f"== {jax_v_major}.{jax_v_minor}.{jax_v_patch}.*" except ImportError: # otherwise we are building a sdist jax_version = ">=0.6.0"