Bonsai

Bonsai is a collection of minimal, lightweight JAX implementations of popular models.

We're committed to making popular models accessible in JAX through simple, hackable, and concise code. Our aim is to lower the barrier to entry for JAX and promote academic innovation.

Tip

For large-scale or industry use on Google Cloud, see MaxText and MaxDiffusion.

Models

The following models are part of Bonsai. Each model's current status shows whether it is ready for full use. We categorize them as follows:

  1. ✅ Ready with broad support
  2. ⚙️ Adding additional features
  3. 🟡 In progress
  4. ⏳ Coming soon (has open PR)

These are listed based on status and then alphabetically.

Model         Type                  Status  Details
Densenet      Image classification
EfficientNet  Image classification
Qwen 3        LLM
ResNet50      Image classification
VGG           Image classification
ViT           Image classification  ⚙️      Update to include sharding
LLaDa         Diffusion LLM         🟡      Need more numerical testing
Sam2          Image segmentation    🟡      Need more numerical testing
UNet          Image                 🟡      Need a reference implementation and numerical testing
VAE           Generative model      🟡      Need a reference implementation and numerical testing
Whisper       Speech recognition    🟡      Need more numerical testing and not all call methods implemented
ConvNeXt      Image classification
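
The ordering rule described above (status first, then alphabetical) can be sketched in a few lines of Python. This is only an illustration, not code from the repository: the status keys (`ready`, `adding_features`, `in_progress`, `coming_soon`) and the `sort_models` helper are hypothetical names, and the sample rows are limited to the models whose status is shown in the table.

```python
# Hypothetical status keys, ranked in the order the README lists them:
# 1 = ready, 2 = adding features, 3 = in progress, 4 = coming soon.
STATUS_RANK = {
    "ready": 1,
    "adding_features": 2,
    "in_progress": 3,
    "coming_soon": 4,
}

# Only the rows whose status appears in the table, deliberately shuffled.
models = [
    ("Whisper", "in_progress"),
    ("ViT", "adding_features"),
    ("UNet", "in_progress"),
    ("LLaDa", "in_progress"),
    ("VAE", "in_progress"),
    ("Sam2", "in_progress"),
]


def sort_models(rows):
    """Sort by status rank first, then case-insensitively by model name."""
    return sorted(rows, key=lambda row: (STATUS_RANK[row[1]], row[0].lower()))


ordered = [name for name, _ in sort_models(models)]
print(ordered)  # ['ViT', 'LLaDa', 'Sam2', 'UNet', 'VAE', 'Whisper']
```

The result matches the order of the corresponding rows in the table.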

Got models you'd like to see in JAX? Add a request or contribute. Please refer to the open issues and PRs before creating a new one to see if a feature is already being addressed.

🏁 Getting Started

To get started with JAX Bonsai, follow these steps to set up your development environment and run the models.

Installation

Clone the JAX Bonsai repository to your local machine.

git clone https://github.com/jax-ml/bonsai.git
cd bonsai

Install the package from the cloned repository in editable mode.

pip install -e .
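
To confirm the install worked, a quick check like the following may help. It is a minimal sketch using only the Python standard library. The module name `bonsai` is an assumption based on the repository layout (`bonsai/models/...`), and it also assumes `jax` is installed as a dependency. The `is_installed` helper is a hypothetical name.

```python
import importlib.util


def is_installed(module_name):
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


# "bonsai" is assumed from the repository layout; "jax" is assumed to be
# pulled in as a dependency of the editable install.
for name in ("jax", "bonsai"):
    status = "found" if is_installed(name) else "missing"
    print(f"{name}: {status}")
```

If either module is reported as missing, re-run the install step from the repository root.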

Running models

Jump right into our Qwen3 model, implemented in 400 lines of code in JAX.

python bonsai/models/qwen3/tests/run_model.py

Contributing

We welcome contributions! If you're interested in adding new models, improving existing implementations, or enhancing documentation, please see our Contributing Guidelines.

Join our Discord to socialize with other JAX enthusiasts.

Useful Links

  • JAX: Learn more about JAX, a super fast NumPy-based ML framework with automatic differentiation.
  • The JAX ecosystem: Unlock unparalleled speed and scale for your next-generation models. Explore an incredible suite of tools and libraries that effortlessly extend JAX's capabilities, transforming how you build, train, and deploy.
  • MaxText and MaxDiffusion: Industry-grade, highly scalable, high-performance JAX model libraries for Google Cloud Platform.
  • JAX LLM Examples: Example high-performance implementations of LLMs in pure JAX.
