
Commit d3294d9

fix headings, H2->H1
1 parent f769b7a commit d3294d9


6 files changed, +66 -66 lines changed


docs/source/ecosystem_overview/architectural.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-## The Architectural Imperative: Performance Beyond Frameworks
+# The Architectural Imperative: Performance Beyond Frameworks

As model architectures converge—for example, on multimodal Mixture-of-Experts (MoE) Transformers—the pursuit of peak performance is leading to the emergence of "Megakernels." A Megakernel is effectively the entire forward pass (or a large portion) of one specific model, hand-coded using a lower-level API like the CUDA SDK on NVIDIA GPUs. This approach achieves maximum hardware utilization by aggressively overlapping compute, memory, and communication. Recent work from the research community has demonstrated that this approach can yield significant throughput gains, over 22% in some cases, for inference on GPUs. This trend is not limited to inference; evidence suggests that some large-scale training efforts have involved low-level hardware control to achieve substantial efficiency gains.

docs/source/ecosystem_overview/comparative.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-## A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice
+# A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice

The modern Machine Learning landscape offers many excellent, mature toolchains. The JAX AI Stack, however, presents a unique and compelling set of advantages for developers focused on large-scale, high-performance ML, stemming directly from its modular design and deep hardware co-design.

docs/source/ecosystem_overview/conclusion.md

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-## Conclusion: A Durable, Production-Ready Platform for the Future of AI
+# Conclusion: A Durable, Production-Ready Platform for the Future of AI

The data in the table above points to a rather simple conclusion: these stacks have their own strengths and weaknesses in a small number of areas, but overall they are broadly similar from a software standpoint. Both stacks provide turnkey solutions for pre-training, post-training adaptation, and deployment of foundation models.

docs/source/ecosystem_overview/core.md

Lines changed: 19 additions & 19 deletions
@@ -1,8 +1,8 @@
-## The Core JAX AI Stack
+# The Core JAX AI Stack

The core JAX AI Stack consists of five key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/) and [Grain](https://google-grain.readthedocs.io/en/latest/).

-### JAX: A Foundation for Composable, High-Performance Program Transformation
+## JAX: A Foundation for Composable, High-Performance Program Transformation

[JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries.

@@ -17,42 +17,42 @@ These core transformations can be mixed and matched to achieve high performance

The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JAX to automatically parallelize computations across large TPU pods with minimal code changes. In most cases, scaling simply requires high-level sharding annotations, a stark contrast to frameworks where scaling may require more manual management of device placement and communication collectives.

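For illustration, a minimal sketch of such sharding annotations, assuming eight available devices arranged as a hypothetical 4x2 mesh (the mesh shape, axis names, and array sizes below are illustrative, not taken from the docs):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 4x2 logical mesh over 8 devices; the axis names are arbitrary labels.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=("data", "model"))

# High-level annotations describe how each array is laid out across the mesh.
x = jax.device_put(jnp.ones((4096, 1024)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((1024, 512)), NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    # No per-device code is written: XLA GSPMD inserts any needed collectives.
    return x @ w

y = forward(x, w)  # the output is itself sharded across the mesh
```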
-### Flax: Flexible Neural Network Authoring and "Model Surgery"
+## Flax: Flexible Neural Network Authoring and "Model Surgery"

[Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. While the pure functional API provided by JAX can be used to fully specify and train an ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more accustomed to the object-oriented approach of specifying models as a graph of `torch.nn.Module` objects. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think in terms of layers rather than functions, making it friendlier to developers who value ergonomics and ease of experimentation. [Flax](https://flax.readthedocs.io/en/stable/) also enables config-driven model construction systems, such as those present in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate model hyperparameters from layer definition code.

With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/).

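A minimal sketch of what this looks like in practice (the `MLP` module below is a hypothetical example, not taken from the Flax documentation):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):
        # Submodules are plain attributes; parameter state lives inside them.
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(jax.nn.relu(self.fc1(x)))

# RNG state is encapsulated by the module rather than threaded by hand.
model = MLP(4, 32, 2, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 4)))  # eager call on a regular Python object
```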
-#### Motivation
+### Motivation

JAX’s pure functional API, while powerful, can be complex for new users since it requires all of the program state to be explicitly managed by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex, with individual portions of the model trained separately and merged to form the final model[^3], in a process commonly referred to as model surgery. Even with decoder-only LLMs, which tend to have a straightforward architecture, post-training techniques such as LoRA and quantization require the model definition to be easily manipulated, allowing parts of the architecture to be modified or even replaced.

The Flax NNX library, with its simple yet powerful Pythonic API, enables this functionality in a way that is intuitive to the user, reducing the cognitive overhead involved in authoring and training a model.

-#### Design
+### Design

The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object-oriented model definition system that encapsulates the model and random number generator state internally, reducing the user's cognitive overhead and providing a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it allows the model definition to be easily edited programmatically for model introspection and surgery.

The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs to allow users to exploit the full expressiveness and performance of JAX, with lifted transformations for common operations like sharding, jit and others. Models defined using the NNX APIs can also be adapted to work with functional training loops, giving users the flexibility they need while retaining an intuitive object-oriented API.

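A short sketch of both ideas, reusing the hypothetical `MLP` from above (the replacement head and shapes are illustrative):

```python
import jax
from flax import nnx

# Model surgery: submodules are attributes, so replacing the head is one line.
model.fc2 = nnx.Linear(32, 8, rngs=nnx.Rngs(1))

# Functional adaptation: split the module into a static graph definition and
# a state PyTree that plain functional JAX code can carry around explicitly.
graphdef, state = nnx.split(model)

@jax.jit
def forward(state, x):
    module = nnx.merge(graphdef, state)  # rebuild the module inside the jitted fn
    return module(x)
```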
-#### Key Strengths
+### Key Strengths

* **Intuitive, flexible object-oriented APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also enabling advanced model surgery use cases through support for submodule replacement, partial initialization and model hierarchy traversal.
* **Consistent with core JAX APIs:** Lifted transformations consistent with core JAX and fully compatible with functional JAX provide the full performance of JAX without sacrificing developer friendliness.

(optax:composable)=
-### Optax: Composable Gradient Processing and Optimization Strategies
+## Optax: Composable Gradient Processing and Optimization Strategies

[Optax](https://optax.readthedocs.io/en/latest/index.html) is a gradient processing and optimization library for JAX. It is designed to empower model builders by providing building blocks that can be recombined in custom ways to train deep learning models, among other applications. It builds on the capabilities of the core JAX library to provide a well-tested, high-performance library of loss and optimizer functions and associated techniques that can be used to train ML models.

-#### Motivation
+### Motivation

The calculation and minimization of losses is at the core of what enables the training of ML models. With its support for automatic differentiation, the core JAX library provides the numeric capabilities to train models, but it does not provide standard implementations of popular optimizers (e.g. `RMSProp`, `Adam`) or losses (`CrossEntropy`, `MSE`, etc.). While a user could implement these functions themselves (and some advanced users will choose to do so), a bug in an optimizer implementation would introduce hard-to-diagnose model quality issues. Rather than having the user implement such critical pieces, [Optax](https://optax.readthedocs.io/en/latest/) provides implementations of these algorithms that are tested for correctness and performance.

The field of optimization theory lies squarely in the realm of research; however, its central role in training also makes it an indispensable part of training production ML models. A library that serves this role needs to be flexible enough to accommodate rapid research iterations and also robust and performant enough to be dependable for production model training. It should also provide well-tested implementations of state-of-the-art algorithms that match the standard equations. The [Optax](https://optax.readthedocs.io/en/latest/) library, through its modular, composable architecture and emphasis on correct, readable code, is designed to achieve this.

-#### Design
+### Design

[Optax](https://optax.readthedocs.io/en/latest/) is designed to enhance both research velocity and the transition from research to production by providing readable, well-tested, and efficient implementations of core algorithms. Optax has uses beyond deep learning; in this context, however, it can be viewed as a collection of well-known loss functions, optimization algorithms and gradient transformations implemented in a pure functional fashion, in line with the JAX philosophy. The collection of well-known [losses](https://optax.readthedocs.io/en/latest/api/losses.html) and [optimizers](https://optax.readthedocs.io/en/latest/api/optimizers.html) enables users to get started with ease and confidence.

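The file's own example sits in the unchanged lines elided between the hunks; a rough sketch of the kind of chained setup it describes, with placeholder hyperparameter values and a toy loss, might look like this:

```python
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((4,))}                                   # toy parameters
loss_fn = lambda p, batch: jnp.sum((p["w"] * batch) ** 2)        # toy loss

# Schedule, clipping and accumulation are declared once, outside the loop.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4, warmup_steps=500, decay_steps=10_000)
optimizer = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(1.0),       # gradient clipping
        optax.adamw(learning_rate=schedule),  # optimizer with LR schedule
    ),
    every_k_schedule=4,                       # 4-step gradient accumulation
)

opt_state = optimizer.init(params)
batch = jnp.arange(4.0)
grads = jax.grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```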
@@ -80,31 +80,31 @@ for i, (inputs, targets) in enumerate(data_loader):

As can be seen in the example above, setting up an optimizer with a custom learning rate, gradient clipping and gradient accumulation is a simple drop-in code change, compared to PyTorch, which forces the user to modify their training loop to directly manage the learning rate scheduler, gradient clipping and gradient accumulation.

-#### Key Strengths
+### Key Strengths

* **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability.
* **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop.
* **Functional and Scalable:** The pure functional implementations integrate seamlessly with JAX's parallelization mechanisms (e.g., pmap), enabling the same code to scale from a single host to large clusters.

(orbax:tensorstore)=
-### Orbax / TensorStore - Large scale distributed checkpointing
+## Orbax / TensorStore - Large scale distributed checkpointing

[**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. The two libraries operate at different levels of the stack: Orbax at the level of ML models and states, TensorStore at the level of individual arrays.

-#### Motivation
+### Motivation

[Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints.

-#### ML Checkpointing vs Generalized Checkpoint/Restore
+### ML Checkpointing vs Generalized Checkpoint/Restore

It is worth considering the difference between ML checkpointing systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMO-Megatron, Torch Distributed Checkpoint) and generalized checkpoint systems like CRIU.

Systems like CRIU & CRIUgpu behave analogously to VM live migration; they halt the entire system and take a snapshot of every last bit of information so it can be faithfully reconstructed. This captures the entirety of the process’s host memory, device memory and operating system state. This is far more information than is actually needed to reconstruct an ML workload, since a very large fraction of it (activations, data examples, file handles) is trivially reconstructed. Capturing this much data also means the job is halted for a long time.

ML checkpointing systems are designed to minimize the amount of time the accelerator is halted by selectively persisting information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data.

-#### Design
+### Design

The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. A higher-level `Checkpointer` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves.

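A minimal sketch of the direct-to-path flow (the path and PyTree below are illustrative, using the basic `StandardCheckpointer` rather than the training-loop utilities):

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Any PyTree of arrays can serve as the checkpointable state.
state = {"params": {"w": jnp.ones((4, 4))}, "step": 100}

ckptr = ocp.StandardCheckpointer()
ckptr.save("/tmp/run/checkpoint_100", state)         # atomic direct-to-path save
restored = ckptr.restore("/tmp/run/checkpoint_100")  # same nested structure back
```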
@@ -125,23 +125,23 @@ Specific industry-leading performance features have their own design challenges,
* [**Restore + broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas.
* **Emergency checkpointing**: Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all.

-#### Key Strengths
+### Key Strengths

* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads.
* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple, but generalizable APIs (direct-to-path, sequence-of-steps).
* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases.
* **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax.


-### Grain: Deterministic and Scalable Input Data Pipelines
+## Grain: Deterministic and Scalable Input Data Pipelines

[Grain](https://google-grain.readthedocs.io/en/latest/) is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast and deterministic, and supports advanced features like checkpointing which are essential to successfully training large workloads. It supports popular data formats and storage backends and also provides a flexible API to extend support to user-specific formats and backends that are not natively supported. While [Grain](https://google-grain.readthedocs.io/en/latest/) is primarily designed to work with JAX, it is framework independent, does not require JAX to run and can be used with other frameworks as well.

-#### Motivation
+### Motivation

Data pipelines form a critical part of the training infrastructure: they need to be flexible, so that common transformations can be expressed efficiently, and performant enough to keep the accelerators busy at all times. They also need to accommodate multiple storage formats and backends. Due to their higher step times, training large models at scale poses unique additional requirements on the data pipeline beyond those of regular training workloads, primarily focused around determinism and reproducibility[^5]. The [Grain](https://google-grain.readthedocs.io/en/latest/) library is designed with a flexible enough architecture to address all these needs.

-#### Design
+### Design

At the highest level, there are two ways to structure an input pipeline: as a separate cluster of data workers, or by co-locating the data workers on the hosts that drive the accelerators. [Grain](https://google-grain.readthedocs.io/en/latest/) chooses the latter for a variety of reasons.

@@ -151,7 +151,7 @@ On the API front, with a pure python implementation that supports multiple proce

Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports efficient random-access data formats like `ArrayRecord` and `Bagz` alongside other popular data formats such as Parquet and `TFDS`. [Grain](https://google-grain.readthedocs.io/en/latest/) includes support for reading from local file systems as well as from GCS by default. Along with supporting popular storage formats and backends, a clean abstraction of the storage layer allows users to easily add support for, or wrap, their existing data sources to be compatible with the [Grain](https://google-grain.readthedocs.io/en/latest/) library.

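A rough sketch of how these pieces compose, assuming Grain's pure-Python `grain.python` API (the range source and `Double` transform are stand-ins for a real dataset and parsing logic):

```python
import grain.python as grain

class Double(grain.MapTransform):
    # A user-defined pure-Python record transformation.
    def map(self, x):
        return 2 * x

source = grain.RangeDataSource(start=0, stop=1024, step=1)  # toy random-access source
sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.NoSharding(),
    shuffle=True,             # stable global shuffle
    num_epochs=1,
    seed=42,
)
loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[Double(), grain.Batch(batch_size=32)],
    worker_count=2,           # co-located multiprocess workers
)
for batch in loader:          # the iterator state itself is checkpointable
    pass
```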
-#### Key Strengths
+### Key Strengths

* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process.
* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline.
