Bonsai is a collection of minimal, lightweight JAX implementations of popular models.
We're committed to making popular models accessible in JAX through simple, hackable, and concise code. Our aim is to lower the barrier to entry for JAX and promote academic innovation.
> [!TIP]
> For large-scale or industry use on Google Cloud, see MaxText and MaxDiffusion.
The following models are part of Bonsai. The status column shows which models are ready for full use, using these categories:
- ✅ Ready with broad support
- ⚙️ Adding additional features
- 🟡 In progress
- ⏳ Coming soon (has open PR)
Models are sorted by status, then alphabetically.
| Model | Type | Status | Details |
|---|---|---|---|
| DenseNet | Image classification | ✅ | |
| EfficientNet | Image classification | ✅ | |
| Qwen 3 | LLM | ✅ | |
| ResNet50 | Image classification | ✅ | |
| VGG | Image classification | ✅ | |
| ViT | Image classification | ⚙️ | Adding sharding support |
| LLaDA | Diffusion LLM | 🟡 | Needs more numerical testing |
| SAM2 | Image segmentation | 🟡 | Needs more numerical testing |
| UNet | Image | 🟡 | Needs a reference implementation and numerical testing |
| VAE | Generative model | 🟡 | Needs a reference implementation and numerical testing |
| Whisper | Speech recognition | 🟡 | Needs more numerical testing; not all call methods are implemented |
| ConvNeXt | Image classification | ⏳ | |
Got models you'd like to see in JAX? Open a request or contribute one. Before creating a new issue or PR, please check the open ones to see whether the feature is already being addressed.
To get started with JAX Bonsai, follow these steps to set up your development environment and run the models.
Clone the JAX Bonsai repository to your local machine.
```bash
git clone https://github.com/jax-ml/bonsai.git
cd bonsai
```

Install the package from your local clone in editable mode.
```bash
pip install -e .
```

Jump right into our Qwen3 model, implemented in 400 lines of code in JAX.
```bash
python bonsai/models/qwen3/tests/run_model.py
```

We welcome contributions! If you're interested in adding new models, improving existing implementations, or enhancing documentation, please see our Contributing Guidelines.
Join our Discord to connect with other JAX enthusiasts.
- JAX: Learn more about JAX, a high-performance ML framework with a NumPy-like API and automatic differentiation.
- The JAX ecosystem: Explore the libraries and tools that extend JAX for building, training, and deploying models at speed and scale.
- MaxText and MaxDiffusion: Highly scalable, high-performance JAX model libraries for industry use on Google Cloud Platform.
- JAX LLM Examples: Example high-performance LLM implementations in pure JAX.