Skip to content

Commit 28b56e7

Browse files
committed
[WIP] Add "The stack" section to left nav
The new section contains an overview of the libraries in the stack, as well as a page for each library with a brief description and outbound links for more information.
1 parent 68efe18 commit 28b56e7

File tree

11 files changed

+92
-0
lines changed

11 files changed

+92
-0
lines changed

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ myst-nb
55
myst-parser[linkify]
66
sphinx-book-theme
77
sphinx-copybutton
8+
sphinx-design
89

910
# Packages required for notebook execution
1011
matplotlib

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
extensions = [
1818
'myst_nb',
1919
'sphinx_copybutton',
20+
'sphinx_design',
2021
]
2122

2223
templates_path = ['_templates']

docs/source/index.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ JAX AI Stack
1212
install
1313
getting_started
1414

15+
.. toctree::
16+
:hidden:
17+
:caption: The stack
18+
:maxdepth: 1
19+
20+
stack_overview
21+
stack_jax
22+
stack_flax
23+
stack_optax
24+
stack_orbax_checkpoint
25+
stack_orbax_export
26+
stack_grain
27+
stack_chex
28+
1529
.. toctree::
1630
:hidden:
1731
:caption: Tutorials

docs/source/stack_chex.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Chex: test utilities

docs/source/stack_flax.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Flax NNX: neural nets
2+
3+
Flax NNX provides **neural net functionality** on top of JAX, such as a module
4+
abstraction and pre-defined layers, via a **Pythonic object-oriented API**. NNX
5+
allows you to write stateful model code that can still take advantage of JAX's
6+
function transforms and other features.
7+
8+
NNX has native integration with [Optax](stack_optax).
9+
10+
Main Flax NNX site:
11+
**[flax.readthedocs.io{material-regular}`open_in_new`](https://flax.readthedocs.io/)**
12+
13+
**If you'd like to learn more about NNX** beyond what's covered in the
14+
[](getting_started) guide, we recommend starting with **[Flax
15+
basics{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/nnx_basics.html)**.
16+
17+
The Flax NNX docs cover many other useful topics including:
18+
19+
* [Function
20+
transforms{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/transforms.html)
21+
* [Parallelism{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html)
22+
* [Performance
23+
considerations{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/performance.html)
24+
* And much more!

docs/source/stack_grain.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Grain: data loading

docs/source/stack_jax.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# JAX: array computing
2+
3+
JAX is the foundation of the JAX AI Stack! It provides **high-performance array
4+
computing** functionality over accelerators via a simple **NumPy-like API and
5+
function transformations**.
6+
7+
Main JAX site: **[jax.dev{material-regular}`open_in_new`](https://jax.dev)**
8+
9+
**If you'd like to learn more about JAX** beyond what's covered in the
10+
[](getting_started) guide, we recommend starting with the **[JAX
11+
tutorials{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/tutorials.html)**.
12+
13+
The JAX docs cover many other useful topics including:
14+
15+
* [Performance profiling{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/profiling.html)
16+
* [Multi-host JAX programs{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/multi_process.html)
17+
* [Custom GPU + TPU kernels with Pallas{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/pallas/index.html)
18+
* And much more!

docs/source/stack_optax.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Optax: optimizers
2+
3+
Optax provides **gradient processing and optimization** functionality on top of
4+
JAX, including optimizers and losses.
5+
6+
Main Optax site:
7+
**[optax.readthedocs.io{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/index.html)**
8+
9+
**If you'd like to learn more about Optax** beyond what's covered in the
10+
[](getting_started) guide, we recommend starting with the **[Optax getting
11+
started{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/getting_started.html)**
12+
guide.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Orbax: checkpointing

docs/source/stack_orbax_export.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Orbax: model export

0 commit comments

Comments
 (0)