Skip to content

Commit ad7a8da

Browse files
authored
Merge branch 'jax-ml:main' into upgrade-vit-2
2 parents bcb5fa1 + 0dd29a7 commit ad7a8da

36 files changed

+1210
-622
lines changed

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
steps:
2424
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
2525
- name: Set up Python 3.12
26-
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
26+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
2727
with:
2828
python-version: 3.12
2929
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1

.github/workflows/nightly.yaml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
steps:
3232
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
3333
- name: Set up Python ${{ matrix.python-version }}
34-
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
34+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
3535
with:
3636
python-version: ${{ matrix.python-version }}
3737
- name: Install dependencies
@@ -56,13 +56,13 @@ jobs:
5656
fail-fast: false
5757
steps:
5858
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
59-
- uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
59+
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
6060
with:
6161
python-version: 3.12
6262
- name: Install dependencies with jax nightly
6363
run: |
6464
python -m pip install --upgrade pip
65-
python -m pip install .[dev,tfds,grain]
65+
python -m pip install .[dev,tfds]
6666
python -m pip install --upgrade --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
6767
- name: Run tests
6868
run: |
@@ -81,13 +81,13 @@ jobs:
8181
fail-fast: false
8282
steps:
8383
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
84-
- uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
84+
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
8585
with:
8686
python-version: 3.12
8787
- name: Install dependencies with flax nightly
8888
run: |
8989
python -m pip install --upgrade pip
90-
python -m pip install .[dev,tfds,grain]
90+
python -m pip install .[dev,tfds]
9191
python -m pip install --upgrade git+https://github.com/google/flax.git
9292
- name: Run tests
9393
run: |
@@ -106,13 +106,13 @@ jobs:
106106
fail-fast: false
107107
steps:
108108
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
109-
- uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
109+
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
110110
with:
111111
python-version: 3.12
112112
- name: Install dependencies with optax nightly
113113
run: |
114114
python -m pip install --upgrade pip
115-
python -m pip install .[dev,tfds,grain]
115+
python -m pip install .[dev,tfds]
116116
python -m pip install --upgrade git+https://github.com/google-deepmind/optax.git
117117
- name: Run tests
118118
run: |
@@ -131,13 +131,13 @@ jobs:
131131
fail-fast: false
132132
steps:
133133
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
134-
- uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
134+
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
135135
with:
136136
python-version: 3.12
137137
- name: Install dependencies with orbax-checkpoint and orbax-export nightly
138138
run: |
139139
python -m pip install --upgrade pip
140-
python -m pip install .[dev,tfds,grain]
140+
python -m pip install .[dev,tfds]
141141
python -m pip install --upgrade 'git+https://github.com/google/orbax/#subdirectory=checkpoint' 'git+https://github.com/google/orbax/#subdirectory=export'
142142
- name: Run tests
143143
run: |
@@ -156,13 +156,13 @@ jobs:
156156
fail-fast: false
157157
steps:
158158
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
159-
- uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
159+
- uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
160160
with:
161161
python-version: 3.12
162162
- name: Install dependencies with chex nightly
163163
run: |
164164
python -m pip install --upgrade pip
165-
python -m pip install .[dev,tfds,grain]
165+
python -m pip install .[dev,tfds]
166166
python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/'
167167
- name: Run tests
168168
run: |

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
permissions:
4444
id-token: write
4545
steps:
46-
- uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # v4.2.1
46+
- uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
4747
with:
4848
name: distribution
4949
path: dist

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ jobs:
3131
steps:
3232
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
3333
- name: Set up Python ${{ matrix.python-version }}
34-
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
34+
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
3535
with:
3636
python-version: ${{ matrix.python-version }}
3737
- name: Install dependencies
3838
run: |
3939
python -m pip install --upgrade pip
40-
pip install .[dev,tfds,grain]
40+
pip install .[dev,tfds]
4141
- name: Run tests
4242
run: |
4343
pytest -n auto jax_ai_stack

README.md

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ single point-of-entry for this suite of libraries, so you can install and begin
2323
using many of the same open source packages that Google developers are using
2424
in their everyday work.
2525

26-
To get started with the JAX AI stack, you can check out [Getting started with JAX](
27-
https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb).
26+
To get started with the JAX AI stack, you can check out [Getting started with JAX](https://docs.jaxstack.ai/en/latest/getting_started.html).
2827
This is still a work-in-progress, please check back for more documentation and tutorials
2928
in the coming weeks!
3029

@@ -44,20 +43,27 @@ together via the integration tests in this repository. Packages include:
4443
- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX.
4544
- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX.
4645
- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code.
46+
- [grain](https://github.com/google/grain): data loading.
4747

4848
### Optional packages
4949

5050
Additionally, there are optional packages you can install with `pip` extras.
51+
5152
The following command:
5253
```
53-
pip install jax-ai-stack[grain]
54+
pip install jax-ai-stack[tfds]
5455
```
55-
will install a compatible version of the [grain](https://github.com/google/grain) data
56-
loader (currently mac and linux-only).
56+
will install a compatible version of
57+
[tensorflow](https://github.com/tensorflow/tensorflow)
58+
and [tensorflow-datasets](https://github.com/tensorflow/datasets).
5759

58-
Similarly, the following command:
60+
### Hardware support
61+
62+
To install `jax-ai-stack` with hardware-specific JAX support, add the JAX installation
63+
command in the same `pip install` invocation. For example:
5964
```
60-
pip install jax-ai-stack[tfds]
65+
pip install jax-ai-stack "jax[cuda]" # JAX + AI stack with GPU/CUDA support
66+
pip install jax-ai-stack "jax[tpu]" # JAX + AI stack with TPU support
6167
```
62-
will install a compatible version of [tensorflow](https://github.com/tensorflow/tensorflow)
63-
and [tensorflow-datasets](https://github.com/tensorflow/datasets).
68+
For more information on available options for hardware-specific JAX installation, refer
69+
to [JAX installation](https://docs.jax.dev/en/latest/installation.html).

0 commit comments

Comments
 (0)