Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .github/actions/setup-cuequivariance-jax/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,28 @@ runs:
if: inputs.use-gpu != 'true'
shell: bash
run: |
python -m uv pip install -U pytest jax
uv pip install -U pytest jax

- name: Install JAX (GPU)
if: inputs.use-gpu == 'true'
shell: bash
run: |
python -m uv pip install -U pytest "jax[cuda12]"
python -m uv pip install nvidia-cusolver-cu12==11.7.3.90
python -m uv pip install nvidia-cublas-cu12
uv pip install -U pytest "jax[cuda12]"
uv pip install nvidia-cusolver-cu12==11.7.3.90
uv pip install nvidia-cublas-cu12

- name: Install common dependencies
shell: bash
run: |
python -m uv pip install triton
python -m uv pip install "flax>=0.12.0"
uv pip install triton
uv pip install "flax>=0.12.0"

- name: Clean and install packages
shell: bash
run: |
python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance_jax
uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true
uv pip install -e ./cuequivariance
uv pip install -e ./cuequivariance_jax

- name: Verify installation
shell: bash
Expand Down
8 changes: 4 additions & 4 deletions .github/actions/setup-cuequivariance-torch/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ runs:
- name: Install PyTorch dependencies
shell: bash
run: |
python -m uv pip install -U pytest torch e3nn
uv pip install -U pytest torch e3nn

- name: Clean and install packages
shell: bash
run: |
python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance_torch
uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch 2>/dev/null || true
uv pip install -e ./cuequivariance
uv pip install -e ./cuequivariance_torch

- name: Verify installation
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/actions/setup-cuequivariance/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ runs:
- name: Install cuequivariance
shell: bash
run: |
python -m uv pip install -e ./cuequivariance[dev]
uv pip install -e ./cuequivariance[dev]
12 changes: 6 additions & 6 deletions .github/actions/setup-docs/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ runs:
- name: Install JAX and PyTorch dependencies
shell: bash
run: |
python -m uv pip install -U pytest jax torch e3nn
python -m uv pip install "flax>=0.12.0"
uv pip install -U pytest jax torch e3nn
uv pip install "flax>=0.12.0"

- name: Install packages
shell: bash
run: |
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance_jax
python -m uv pip install -e ./cuequivariance_torch
python -m uv pip install -r docs/requirements.txt
uv pip install -e ./cuequivariance
uv pip install -e ./cuequivariance_jax
uv pip install -e ./cuequivariance_torch
uv pip install -r docs/requirements.txt

- name: Verify installation
shell: bash
Expand Down
19 changes: 10 additions & 9 deletions .github/actions/setup-python-uv/action.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: 'Setup Python with uv'
description: 'Set up Python and install uv package manager'
name: 'Setup uv and Python'
description: 'Set up uv package manager and Python'

inputs:
python-version:
Expand All @@ -10,13 +10,14 @@ inputs:
runs:
using: 'composite'
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ inputs.python-version }}
- name: Install uv
enable-cache: true

- name: Create venv with Python ${{ inputs.python-version }}
shell: bash
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade uv
uv venv --python ${{ inputs.python-version }} .venv
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
echo "VIRTUAL_ENV=${{ github.workspace }}/.venv" >> $GITHUB_ENV
2 changes: 1 addition & 1 deletion .github/workflows/nightly-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:

- name: Downgrade numpy
run: |
python -m uv pip install -U "numpy==1.26.*"
uv pip install -U "numpy==1.26.*"

- name: Test with pytest (numpy 1.26, including slow tests)
run: |
Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@ jobs:

steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
- uses: ./.github/actions/setup-python-uv
with:
python-version: "3.12"
- name: Setup Pre-commit
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade uv
python -m uv pip install pre-commit
uv pip install pre-commit
pre-commit install
- name: Run Pre-commit
run: |
Expand Down
42 changes: 22 additions & 20 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,30 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.12"]
python-version: ["3.10", "3.14"]

steps:
- uses: actions/checkout@v4

- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ matrix.python-version }}

- uses: ./.github/actions/setup-cuequivariance
with:
install-graphviz: 'true'

- name: Test with pytest
run: |
pytest --doctest-modules -x -m "not slow" cuequivariance

- name: Downgrade numpy
if: matrix.python-version == '3.10'
run: |
python -m uv pip install -U "numpy==1.26.*"
uv pip install -U "numpy==1.26.*"

- name: Test with pytest (numpy 1.26)
if: matrix.python-version == '3.10'
run: |
pytest --doctest-modules -x -m "not slow" cuequivariance

Expand All @@ -45,17 +47,17 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.11", "3.13"]
python-version: ["3.11", "3.14"]

steps:
- uses: actions/checkout@v4

- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ matrix.python-version }}

- uses: ./.github/actions/setup-cuequivariance-jax

- name: Test with pytest
run: |
XLA_PYTHON_CLIENT_PREALLOCATE=false pytest --doctest-modules -x -m "not slow" cuequivariance_jax
Expand All @@ -66,17 +68,17 @@ jobs:

steps:
- uses: actions/checkout@v4

- uses: ./.github/actions/setup-python-uv
with:
python-version: "3.12"

- name: Install without flax
run: |
python -m uv pip install -U jax
python -m uv pip install -e ./cuequivariance
python -m uv pip install -e ./cuequivariance_jax
uv pip install -U jax
uv pip install -e ./cuequivariance
uv pip install -e ./cuequivariance_jax

- name: Verify import without flax
run: |
python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)"
Expand All @@ -93,13 +95,13 @@ jobs:

steps:
- uses: actions/checkout@v4

- uses: ./.github/actions/setup-python-uv
with:
python-version: ${{ matrix.python-version }}

- uses: ./.github/actions/setup-cuequivariance-torch

- name: Test with pytest
run: |
pytest --doctest-modules -x -m "not slow" cuequivariance_torch
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ __pycache__
docs/api/generated/
docs/public/
docs/jupyter_execute/
docs/_build/
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,15 @@ def add_segment(
"""Add a segment to the descriptor."""
if isinstance(segment, dict):
segment = tuple(segment[m] for m in self.subscripts.operands[operand])
return self.operands[operand].add_segment(segment)
result = self.operands[operand].add_segment(segment)
# Rebuild tuple to invalidate CPython's cached tuple hashes,
# which become stale after the operand is mutated in-place.
object.__setattr__(
self,
"operands_and_subscripts",
tuple((ope, ss) for ope, ss in self.operands_and_subscripts),
)
return result

def add_segments(
self, operand: int, segments: list[Union[tuple[int, ...], dict[str, int]]]
Expand Down
Loading