diff --git a/.github/actions/setup-cuequivariance-jax/action.yml b/.github/actions/setup-cuequivariance-jax/action.yml index 48c58a0..23adc2e 100644 --- a/.github/actions/setup-cuequivariance-jax/action.yml +++ b/.github/actions/setup-cuequivariance-jax/action.yml @@ -14,28 +14,28 @@ runs: if: inputs.use-gpu != 'true' shell: bash run: | - python -m uv pip install -U pytest jax + uv pip install -U pytest jax - name: Install JAX (GPU) if: inputs.use-gpu == 'true' shell: bash run: | - python -m uv pip install -U pytest "jax[cuda12]" - python -m uv pip install nvidia-cusolver-cu12==11.7.3.90 - python -m uv pip install nvidia-cublas-cu12 + uv pip install -U pytest "jax[cuda12]" + uv pip install nvidia-cusolver-cu12==11.7.3.90 + uv pip install nvidia-cublas-cu12 - name: Install common dependencies shell: bash run: | - python -m uv pip install triton - python -m uv pip install "flax>=0.12.0" + uv pip install triton + uv pip install "flax>=0.12.0" - name: Clean and install packages shell: bash run: | - python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true - python -m uv pip install -e ./cuequivariance - python -m uv pip install -e ./cuequivariance_jax + uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true + uv pip install -e ./cuequivariance + uv pip install -e ./cuequivariance_jax - name: Verify installation shell: bash diff --git a/.github/actions/setup-cuequivariance-torch/action.yml b/.github/actions/setup-cuequivariance-torch/action.yml index f960881..87f2bdc 100644 --- a/.github/actions/setup-cuequivariance-torch/action.yml +++ b/.github/actions/setup-cuequivariance-torch/action.yml @@ -7,14 +7,14 @@ runs: - name: Install PyTorch dependencies shell: bash run: | - python -m uv pip install -U pytest torch e3nn + uv pip install -U pytest torch e3nn - name: Clean and install packages shell: bash run: | - python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true - python -m uv pip install -e ./cuequivariance - python -m uv pip install -e ./cuequivariance_torch + uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true + uv pip install -e ./cuequivariance + uv pip install -e ./cuequivariance_torch - name: Verify installation shell: bash diff --git a/.github/actions/setup-cuequivariance/action.yml b/.github/actions/setup-cuequivariance/action.yml index 7f8a54d..3db9e0e 100644 --- a/.github/actions/setup-cuequivariance/action.yml +++ b/.github/actions/setup-cuequivariance/action.yml @@ -20,4 +20,4 @@ runs: - name: Install cuequivariance shell: bash run: | - python -m uv pip install -e ./cuequivariance[dev] + uv pip install -e ./cuequivariance[dev] diff --git a/.github/actions/setup-docs/action.yml b/.github/actions/setup-docs/action.yml index d297c7d..354eee6 100644 --- a/.github/actions/setup-docs/action.yml +++ b/.github/actions/setup-docs/action.yml @@ -13,16 +13,16 @@ runs: - name: Install JAX and PyTorch dependencies shell: bash run: | - python -m uv pip install -U pytest jax torch e3nn - python -m uv pip install "flax>=0.12.0" + uv pip install -U pytest jax torch e3nn + uv pip install "flax>=0.12.0" - name: Install packages shell: bash run: | - python -m uv pip install -e ./cuequivariance - python -m uv pip install -e ./cuequivariance_jax - python -m uv pip install -e ./cuequivariance_torch - python -m uv pip install -r docs/requirements.txt + uv pip install -e ./cuequivariance + uv pip install -e ./cuequivariance_jax + uv pip install -e ./cuequivariance_torch + uv pip install -r docs/requirements.txt - name: Verify installation shell: bash diff --git a/.github/actions/setup-python-uv/action.yml b/.github/actions/setup-python-uv/action.yml index a208d05..3db8efc 100644 --- a/.github/actions/setup-python-uv/action.yml +++ b/.github/actions/setup-python-uv/action.yml @@ -1,5 +1,5 @@ -name: 'Setup Python with uv' -description: 'Set up Python and install uv package manager' +name: 'Setup uv and Python' +description: 'Set up uv package manager and Python' inputs: python-version: @@ -10,13 +10,14 @@ inputs: runs: using: 'composite' steps: - - name: Set up Python ${{ inputs.python-version }} - uses: actions/setup-python@v5 + - name: Set up uv + uses: astral-sh/setup-uv@v7 with: - python-version: ${{ inputs.python-version }} - - - name: Install uv + enable-cache: true + + - name: Create venv with Python ${{ inputs.python-version }} shell: bash run: | - python -m pip install --upgrade pip - python -m pip install --upgrade uv + uv venv --python ${{ inputs.python-version }} .venv + echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH + echo "VIRTUAL_ENV=${{ github.workspace }}/.venv" >> $GITHUB_ENV diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml index 423f472..23bf345 100644 --- a/.github/workflows/nightly-tests.yml +++ b/.github/workflows/nightly-tests.yml @@ -56,7 +56,7 @@ jobs: - name: Downgrade numpy run: | - python -m uv pip install -U "numpy==1.26.*" + uv pip install -U "numpy==1.26.*" - name: Test with pytest (numpy 1.26, including slow tests) run: | diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index e2226bd..e61f8af 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -16,15 +16,12 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - uses: ./.github/actions/setup-python-uv with: python-version: "3.12" - name: Setup Pre-commit run: | - python -m pip install --upgrade pip - python -m pip install --upgrade uv - python -m uv pip install pre-commit + uv pip install pre-commit pre-commit install - name: Run Pre-commit run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8aa7fee..920e7fb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,28 +14,30 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12"] + python-version: ["3.10", "3.14"] steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance with: install-graphviz: 'true' - + - name: Test with pytest run: | pytest --doctest-modules -x -m "not slow" cuequivariance - + - name: Downgrade numpy + if: matrix.python-version == '3.10' run: | - python -m uv pip install -U "numpy==1.26.*" - + uv pip install -U "numpy==1.26.*" + - name: Test with pytest (numpy 1.26) + if: matrix.python-version == '3.10' run: | pytest --doctest-modules -x -m "not slow" cuequivariance @@ -45,17 +47,17 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.11", "3.13"] + python-version: ["3.11", "3.14"] steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance-jax - + - name: Test with pytest run: | XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax @@ -66,17 +68,17 @@ jobs: steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: "3.12" - + - name: Install without flax run: | - python -m uv pip install -U jax - python -m uv pip install -e ./cuequivariance - python -m uv pip install -e ./cuequivariance_jax - + uv pip install -U jax + uv pip install -e ./cuequivariance + uv pip install -e ./cuequivariance_jax + - name: Verify import without flax run: | python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" @@ -93,13 +95,13 @@ jobs: steps: - uses: actions/checkout@v4 - + - uses: ./.github/actions/setup-python-uv with: python-version: ${{ matrix.python-version }} - + - uses: ./.github/actions/setup-cuequivariance-torch - + - name: Test with pytest run: | pytest --doctest-modules -x -m "not slow" cuequivariance_torch diff --git a/.gitignore b/.gitignore index efa6755..5bdb1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ docs/api/generated/ docs/public/ docs/jupyter_execute/ +docs/_build/ diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index 3a5a134..359598e 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -817,7 +817,15 @@ def add_segment( """Add a segment to the descriptor.""" if isinstance(segment, dict): segment = tuple(segment[m] for m in self.subscripts.operands[operand]) - return self.operands[operand].add_segment(segment) + result = self.operands[operand].add_segment(segment) + # Rebuild tuple to invalidate CPython's cached tuple hashes, + # which become stale after the operand is mutated in-place. + object.__setattr__( + self, + "operands_and_subscripts", + tuple((ope, ss) for ope, ss in self.operands_and_subscripts), + ) + return result def add_segments( self, operand: int, segments: list[Union[tuple[int, ...], dict[str, int]]]