# The Architectural Imperative: Performance Beyond Frameworks
As model architectures converge, for example on multimodal Mixture-of-Experts (MoE) Transformers, the pursuit of peak performance is leading to the emergence of "Megakernels." A Megakernel is effectively the entire forward pass (or a large portion of it) of one specific model, hand-coded using a lower-level API such as the CUDA SDK on NVIDIA GPUs. This approach achieves maximum hardware utilization by aggressively overlapping compute, memory, and communication. Recent work from the research community has demonstrated that this approach can yield significant throughput gains, over 22% in some cases, for inference on GPUs. This trend is not limited to inference; evidence suggests that some large-scale training efforts have involved low-level hardware control to achieve substantial efficiency gains.
# A Comparative Perspective: The JAX/TPU Stack as a Compelling Choice
The modern Machine Learning landscape offers many excellent, mature toolchains. The JAX AI Stack, however, presents a unique and compelling set of advantages for developers focused on large-scale, high-performance ML, stemming directly from its modular design and deep hardware co-design.
# Conclusion: A Durable, Production-Ready Platform for the Future of AI
The data in the table above leads to a fairly simple conclusion: these stacks have their own strengths and weaknesses in a small number of areas, but overall they are broadly similar from a software standpoint. Both stacks provide out-of-the-box, turnkey solutions for pre-training, post-training adaptation, and deployment of foundation models.
# The Core JAX AI Stack
The core JAX AI Stack consists of five key libraries that provide the foundation for model development: JAX, [Flax](https://flax.readthedocs.io/en/stable/), [Optax](https://optax.readthedocs.io/en/latest/), [Orbax](https://orbax.readthedocs.io/en/latest/) and [Grain](https://google-grain.readthedocs.io/en/latest/).
## JAX: A Foundation for Composable, High-Performance Program Transformation
[JAX](https://docs.jax.dev/en/latest/) is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale Machine Learning. With its functional programming model and friendly, NumPy-like API, JAX provides a solid foundation for higher-level libraries.
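As a brief, illustrative sketch of this programming model (not taken verbatim from the JAX documentation), composable transformations such as `jax.grad` and `jax.jit` can be layered on top of the NumPy-like API:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # A tiny linear model written with the NumPy-like API.
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

# Compose transformations: differentiate with respect to params,
# then JIT-compile the resulting function with XLA.
grad_fn = jax.jit(jax.grad(loss))

params = (jnp.ones((3, 1)), jnp.zeros((1,)))
x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))
grads = grad_fn(params, x, y)  # a PyTree of gradients matching the structure of params
```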
The seamless integration with XLA's GSPMD (General-purpose SPMD) model allows JAX to automatically parallelize computations across large TPU pods with minimal code changes. In most cases, scaling simply requires high-level sharding annotations, a stark contrast to frameworks where scaling may require more manual management of device placement and communication collectives.
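A rough sketch of what such a sharding annotation can look like (a simplified, single-axis example, assuming the public `jax.sharding.Mesh`, `NamedSharding`, and `PartitionSpec` APIs):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the available devices into a 1D mesh with a named "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # shard the leading (batch) dimension

batch = jax.device_put(jnp.zeros((64, 128)), sharding)

@jax.jit
def scale(x):
    # GSPMD propagates the sharding through the computation;
    # no per-device placement or collectives are written by hand.
    return x * 2.0

out = scale(batch)
```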
## Flax: Flexible Neural Network Authoring and "Model Surgery"
[Flax](https://flax.readthedocs.io/en/latest/index.html) is a library designed to simplify the creation, debugging, and analysis of neural networks in JAX. While the pure functional API provided by JAX can be used to fully specify and train an ML model, users coming from the PyTorch (or TensorFlow) ecosystem are more accustomed to the object-oriented approach of specifying models as a graph of `torch.nn.Module` objects. The abstractions provided by [Flax](https://flax.readthedocs.io/en/stable/) allow users to think in terms of layers rather than functions, making it more developer-friendly for an audience that values ergonomics and ease of experimentation. [Flax](https://flax.readthedocs.io/en/stable/) also enables config-driven model construction systems, such as those in [MaxText](https://maxtext.readthedocs.io/en/latest/) and AxLearn, which separate model hyperparameters from layer definition code.
With a simple Pythonic API, it allows developers to express models using regular Python objects, while retaining the power and performance of JAX. Flax's NNX API is an evolution of the Flax Linen interface, incorporating lessons learned to offer a more user-friendly interface that remains consistent with the core JAX APIs. Since Flax modules are fully backed by the core JAX APIs, there is no performance penalty associated with defining the model in [Flax](https://flax.readthedocs.io/en/stable/).
### Motivation
JAX's pure functional API, while powerful, can be complex for new users since it requires all program state to be managed explicitly by the user. This paradigm can be unfamiliar to developers used to other frameworks. Modern model architectures are often complex, with individual portions of the model trained separately and merged to form the final model[^3], a process commonly referred to as model surgery. Even with decoder-only LLMs, which tend to have a straightforward architecture, post-training techniques such as LoRA and quantization require the model definition to be easily manipulable, allowing parts of the architecture to be modified or even replaced.
The Flax NNX library, with its simple yet powerful Pythonic API, enables this functionality in a way that is intuitive to the user, reducing the cognitive overhead involved in authoring and training a model.
### Design
The [Flax](https://flax.readthedocs.io/en/stable/) NNX library introduces an object-oriented model definition system that encapsulates the model and random number generator state internally, reducing the user's cognitive overhead and providing a familiar experience for those accustomed to frameworks like PyTorch or TensorFlow. By making submodule definitions Pythonic and providing APIs to traverse the module hierarchy, it makes the model definition easily editable programmatically for model introspection and surgery.
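A minimal sketch of this style, assuming the public `flax.nnx` `Module`, `Linear`, and `Rngs` APIs; the `LoRALinear` wrapper below is a hypothetical class used only to illustrate the model surgery pattern, not a Flax API:

```python
from flax import nnx
import jax.numpy as jnp

class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.fc2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.fc2(nnx.relu(self.fc1(x)))

model = MLP(4, 32, 2, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 4)))  # modules hold their own state; no explicit parameter threading

# "Model surgery": submodules are plain Python attributes, so they can be
# inspected or swapped out directly, e.g. replaced by a LoRA-style wrapper
# (hypothetical class, shown only to illustrate the pattern).
class LoRALinear(nnx.Module):
    def __init__(self, base: nnx.Linear, rank: int, *, rngs: nnx.Rngs):
        self.base = base
        self.lora_a = nnx.Linear(base.in_features, rank, use_bias=False, rngs=rngs)
        self.lora_b = nnx.Linear(rank, base.out_features, use_bias=False, rngs=rngs)

    def __call__(self, x):
        return self.base(x) + self.lora_b(self.lora_a(x))

model.fc2 = LoRALinear(model.fc2, rank=4, rngs=nnx.Rngs(1))
```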
The [Flax](https://flax.readthedocs.io/en/stable/) NNX APIs are designed to be consistent with the core JAX APIs, allowing users to exploit the full expressiveness and performance of JAX, with lifted transformations for common operations such as sharding and jit. Models defined using the NNX APIs can also be adapted to work with functional training loops, giving users the flexibility they need while retaining an intuitive object-oriented API.
### Key Strengths
* **Intuitive, object-oriented, flexible APIs:** Layers are represented as pure Python objects with internal state management, simplifying model construction and training loops, while also enabling advanced model surgery use cases through support for submodule replacement, partial initialization, and model hierarchy traversal.
* **Consistent with core JAX APIs:** Lifted transformations consistent with core JAX and fully compatible with functional JAX provide the full performance of JAX without sacrificing developer friendliness.
(optax:composable)=
## Optax: Composable Gradient Processing and Optimization Strategies
[Optax](https://optax.readthedocs.io/en/latest/index.html) is a gradient processing and optimization library for JAX. It is designed to empower model builders by providing building blocks that can be recombined in custom ways to train deep learning models, among other applications. It builds on the capabilities of the core JAX library to provide a well-tested, high-performance collection of losses, optimizers, and associated techniques that can be used to train ML models.
### Motivation
The calculation and minimization of losses is at the core of what enables the training of ML models. With its support for automatic differentiation, the core JAX library provides the numeric capabilities to train models, but it does not provide standard implementations of popular optimizers (e.g. `RMSProp`, `Adam`) or losses (e.g. cross-entropy, MSE). While it is true that a user could implement these functions themselves (and some advanced users will choose to do so), a bug in an optimizer implementation would introduce hard-to-diagnose model quality issues. Rather than having the user implement such critical pieces, [Optax](https://optax.readthedocs.io/en/latest/) provides implementations of these algorithms that are tested for correctness and performance.
The field of optimization theory lies squarely in the realm of research; however, its central role also makes it an indispensable part of training production ML models. A library that serves this role needs to be flexible enough to accommodate rapid research iteration and also robust and performant enough to be dependable for production model training. It should also provide well-tested implementations of state-of-the-art algorithms that match the standard equations. The [Optax](https://optax.readthedocs.io/en/latest/) library, with its modular, composable architecture and emphasis on correct, readable code, is designed to achieve this.
### Design
[Optax](https://optax.readthedocs.io/en/latest/) is designed to enhance both research velocity and the transition from research to production by providing readable, well-tested, and efficient implementations of core algorithms. Optax has uses beyond deep learning; in this context, however, it can be viewed as a collection of well-known loss functions, optimization algorithms, and gradient transformations implemented in a pure functional fashion, in line with the JAX philosophy. The collection of well-known [losses](https://optax.readthedocs.io/en/latest/api/losses.html) and [optimizers](https://optax.readthedocs.io/en/latest/api/optimizers.html) enables users to get started with ease and confidence.
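As a rough, illustrative sketch of what such a composition can look like (the values below are placeholders; the assumed building blocks are `optax.chain`, `optax.clip_by_global_norm`, `optax.adamw`, `optax.warmup_cosine_decay_schedule`, and `optax.MultiSteps`):

```python
import optax

# Compose a learning-rate schedule, gradient clipping, and an optimizer declaratively...
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4, warmup_steps=1_000, decay_steps=100_000)

tx = optax.chain(
    optax.clip_by_global_norm(1.0),        # gradient clipping
    optax.adamw(learning_rate=schedule),   # optimizer driven by the schedule
)
# ...and wrap it to accumulate gradients over 4 micro-batches per update.
tx = optax.MultiSteps(tx, every_k_schedule=4)

# Usage inside a functional training step:
# opt_state = tx.init(params)
# updates, opt_state = tx.update(grads, opt_state, params)
# params = optax.apply_updates(params, updates)
```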
As can be seen in the example above, setting up an optimizer with a custom learning rate, gradient clipping, and gradient accumulation is a simple drop-in replacement block of code, compared to PyTorch, which forces the user to modify their training loop to directly manage the learning rate scheduler, gradient clipping, and gradient accumulation.
### Key Strengths
* **Robust Library:** Provides a comprehensive library of losses, optimizers, and algorithms with a focus on correctness and readability.
* **Modular Chainable Transformations:** As shown above, this flexible API allows users to craft powerful, complex optimization strategies declaratively, without modifying the training loop.
* **Functional and Scalable:** The pure functional implementations integrate seamlessly with JAX's parallelization mechanisms (e.g., pmap), enabling the same code to scale from a single host to large clusters.
(orbax:tensorstore)=
## Orbax / TensorStore: Large-Scale Distributed Checkpointing
[**Orbax**](https://orbax.readthedocs.io/en/latest/) is an any-scale checkpointing library for JAX users, backed primarily by [**TensorStore**](https://google.github.io/tensorstore/), a library for efficiently reading and writing multi-dimensional arrays. The two libraries operate at different levels of the stack: Orbax at the level of ML models and training state, TensorStore at the level of individual arrays.
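A minimal sketch of the array-level layer, assuming TensorStore's `zarr` driver and a local `file` kvstore (the path and shape are placeholders):

```python
import numpy as np
import tensorstore as ts

# Open (and create) a chunked array on local disk via the zarr driver.
store = ts.open({
    "driver": "zarr",
    "kvstore": {"driver": "file", "path": "/tmp/ts_example/"},
}, create=True, dtype=ts.float32, shape=[1000, 1000]).result()

# Write and read back a small region; I/O is asynchronous under the hood.
store[0:4, 0:4].write(np.arange(16, dtype=np.float32).reshape(4, 4)).result()
tile = store[0:4, 0:4].read().result()
```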
### Motivation
[Orbax](https://orbax.readthedocs.io/en/latest/), which centers on JAX users and ML checkpointing, aims to reduce the fragmentation of checkpointing implementations across disparate research codebases, increase adoption of important performance features outside the most cutting-edge codebases, and provide a clean, flexible API for novice and advanced users alike. With advanced features like fully asynchronous distributed checkpointing, multi-tier checkpointing and emergency checkpointing, [Orbax](https://orbax.readthedocs.io/en/latest/) enables resilience in the largest of training jobs while also providing a flexible representation for publishing checkpoints.
### ML Checkpointing vs Generalized Checkpoint/Restore
It is worth considering the difference between ML checkpoint systems ([Orbax](https://orbax.readthedocs.io/en/latest/), NeMo-Megatron, Torch Distributed Checkpoint) and generalized checkpoint/restore systems like CRIU.
Systems like CRIU and CRIUgpu behave analogously to VM live migration; they halt the entire system and take a snapshot of every last bit of information so it can be faithfully reconstructed. This captures the entirety of the process's host memory, device memory, and operating system state. This is far more information than is actually needed to reconstruct an ML workload, since a very large fraction of it (activations, data examples, file handles) can be trivially reconstructed. Capturing this much data also means the job is halted for a significant amount of time.
ML checkpoint systems are designed to minimize the amount of time the accelerator is halted by selectively persisting information that cannot be reconstructed. Specifically, this entails persisting model weights, optimizer state, dataloader state and random number generator state, which is a far smaller amount of data.
### Design
The [Orbax API](https://orbax.readthedocs.io/en/latest/index.html) centers around handling [PyTrees](https://docs.jax.dev/en/latest/pytrees.html) (nested containers) of arrays as the standard representation of JAX models. Saving and loading can be synchronous or asynchronous, with saving consisting of blocking and non-blocking phases. A higher-level `CheckpointManager` class is provided, which facilitates checkpointing in a training loop, with save intervals, garbage collection, dataset checkpointing, and metadata management. Finally, Orbax provides customization layers for dealing with user-defined checkpointable objects and PyTree leaves.
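A minimal sketch of the basic save/restore flow, assuming the `orbax.checkpoint` `StandardCheckpointer` API and a placeholder local path:

```python
from pathlib import Path
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Any PyTree of arrays (e.g. model params plus optimizer state) can be checkpointed.
state = {"params": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}, "step": 100}

ckpt_dir = Path("/tmp/ckpt_example")  # placeholder path; remote stores such as GCS are also supported
ckptr = ocp.StandardCheckpointer()
ckptr.save(ckpt_dir / "step_100", state)

restored = ckptr.restore(ckpt_dir / "step_100")
# For periodic saves inside a training loop (save intervals, async saving,
# garbage collection), the higher-level CheckpointManager is the entry point.
```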
* [**Restore + broadcast**](https://cloud.google.com/blog/products/compute/unlock-faster-workload-start-time-using-orbax-on-jax): Hero-scale training runs replicate the model weights among multiple data-parallel replicas. Orbax provides a load balancing feature that distributes the burden evenly among available replicas when saving. It also leverages fast chip interconnects to avoid redundant reads of the model on different groups of hosts, instead loading on a single primary replica and broadcasting the weights to all other replicas.
* **Emergency checkpointing:** Hero-scale training suffers from frequent interruptions and hardware failures. Checkpointing to persistent RAM disk improves goodput for hero-scale jobs by allowing for increased checkpoint frequency, faster restore times, and improved resiliency, since TPU states may be corrupted on some replicas, but not all.
### Key Strengths
* **Widespread adoption:** As checkpoints are a medium for communication of ML artifacts between different codebases and stages of ML development, widespread adoption is an inherent advantage. Currently, Orbax has [~4 million](https://pypistats.org/packages/orbax-checkpoint) monthly package downloads.
* **Easy to use:** Orbax abstracts away complex technical aspects of checkpointing like async saving, single- vs. multi-controller, checkpoint atomicity, distributed filesystem details, TPU vs. GPU, etc. It condenses use cases into simple but generalizable APIs (direct-to-path, sequence-of-steps).
* **Flexible:** While Orbax focuses on exposing a simple API surface for the majority of users, additional layers for handling custom checkpointable objects and PyTree nodes allow for flexibility in specialized use cases.
* **Performant and scalable:** Orbax provides a variety of features designed to make checkpointing as fast and as unobtrusive as possible, freeing developers to focus on efficiency in the remainder of the training loop. Scalability to the cutting edge of ML research is a top concern of the library; training runs at a scale of O(10k) nodes currently rely on Orbax.
## Grain: Deterministic and Scalable Input Data Pipelines
[Grain](https://google-grain.readthedocs.io/en/latest/) is a Python library for reading and processing data for training and evaluating JAX models. It is flexible, fast, and deterministic, and supports advanced features like checkpointing that are essential to successfully training large workloads. It supports popular data formats and storage backends, and also provides a flexible API to extend support to user-specific formats and backends that are not natively supported. While [Grain](https://google-grain.readthedocs.io/en/latest/) is primarily designed to work with JAX, it is framework independent, does not require JAX to run, and can be used with other frameworks as well.
### Motivation
Data pipelines form a critical part of the training infrastructure: they need to be flexible, so that common transformations can be expressed efficiently, and performant enough to keep the accelerators busy at all times. They also need to accommodate multiple storage formats and backends. Due to their higher step times, training runs for large models at scale pose unique additional requirements on the data pipeline beyond those of regular training workloads, primarily focused on determinism and reproducibility[^5]. The [Grain](https://google-grain.readthedocs.io/en/latest/) library is designed with a flexible enough architecture to address all these needs.
### Design
At the highest level, there are two ways to structure an input pipeline: as a separate cluster of data workers, or by co-locating the data workers on the hosts that drive the accelerators. [Grain](https://google-grain.readthedocs.io/en/latest/) chooses the latter for a variety of reasons.
Out of the box, [Grain](https://google-grain.readthedocs.io/en/latest/) supports efficient random-access data formats like `ArrayRecord` and `Bagz` alongside other popular data formats such as Parquet and `TFDS`. [Grain](https://google-grain.readthedocs.io/en/latest/) includes support for reading from local file systems as well as reading from GCS by default. Along with supporting popular storage formats and backends, a clean abstraction over the storage layer allows users to easily add support for, or wrap, their existing data sources to be compatible with the [Grain](https://google-grain.readthedocs.io/en/latest/) library.
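A minimal sketch of a pipeline built with Grain's dataset API (assuming `grain.MapDataset` with `source`, `shuffle`, `map`, and `batch`; an in-memory list stands in for a real data source):

```python
import grain

# An in-memory source stands in here for ArrayRecord, Parquet, etc.
dataset = (
    grain.MapDataset.source([{"x": i} for i in range(1024)])
    .shuffle(seed=42)                      # stable, reproducible global shuffle
    .map(lambda ex: {"x": ex["x"] * 2})    # pure-Python transformation
    .batch(batch_size=8)
)

for batch in dataset:
    ...  # feed each batch to the training step
```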
### Key Strengths
* **Deterministic data feeding:** Colocating the data worker with the accelerator and coupling it with a stable global shuffle and [checkpointable iterators](https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#checkpointing) allows the model state and data pipeline state to be checkpointed together in a consistent snapshot using [Orbax](https://orbax.readthedocs.io/en/latest/), enhancing the determinism of the training process.
* **Flexible APIs to enable powerful data transformations:** A flexible pure Python [transformations](https://google-grain.readthedocs.io/en/latest/data_loader/transformations.html) API allows for extensive data transformations within the input processing pipeline.