A minimal, high-performance, JAX-based implementation of 3D Gaussian Splatting, organized as a clean, modular architecture in the `jax_gs` package.
- Clean Architecture: Core logic is split into modules under `jax_gs` (`core`, `renderer`, `io`, `training`).
- Optimized Tile Rasterizer: JAX-native implementation with efficient bit-packed sorting, supporting CPU, CUDA, and Apple Silicon (MPS).
- Fast GPU Execution: Optimized for NVIDIA L4 GPUs with full `float32` throughput.
- Resume Training: Continue training from any saved PLY checkpoint.
- Unit Tested: Comprehensive test suite covering mathematical correctness and IO.
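The bit-packed sorting idea mentioned above can be illustrated with a small, pure-Python sketch of the scheme used by the original CUDA 3DGS rasterizer: every (Gaussian, tile) overlap becomes one 64-bit integer key, with the tile index in the upper 32 bits and the float32 depth bits in the lower 32, so a single sort groups instances by tile and orders them front to back. The 16-pixel tile size, the square screen-space footprint, and all function names here are assumptions for illustration; `jax_gs` may use a different key layout (for instance, JAX disables 64-bit integers unless `jax_enable_x64` is set).

```python
import math
import struct

TILE = 16  # tile edge in pixels (assumed value for this sketch)


def depth_to_bits(depth: float) -> int:
    """Reinterpret a non-negative float32 depth as an unsigned 32-bit integer.

    For non-negative IEEE-754 floats the bit pattern grows with the value,
    so comparing these integers compares the depths.
    """
    return struct.unpack("<I", struct.pack("<f", depth))[0]


def pack_key(tile_id: int, depth: float) -> int:
    # High 32 bits: tile index. Low 32 bits: depth bits.
    return (tile_id << 32) | depth_to_bits(depth)


def overlapped_tiles(cx, cy, radius, width, height):
    """Row-major ids of the tiles covered by the square [c - r, c + r]."""
    tiles_x = math.ceil(width / TILE)
    tiles_y = math.ceil(height / TILE)
    x0 = max(0, int((cx - radius) // TILE))
    x1 = min(tiles_x - 1, int((cx + radius) // TILE))
    y0 = max(0, int((cy - radius) // TILE))
    y1 = min(tiles_y - 1, int((cy + radius) // TILE))
    # Empty when the footprint lies entirely off screen (x0 > x1 or y0 > y1).
    return [ty * tiles_x + tx for ty in range(y0, y1 + 1) for tx in range(x0, x1 + 1)]


def sort_instances(splats, width, height):
    """splats: list of (cx, cy, radius, depth) with depth > 0.

    Returns (tile_id, splat_index) pairs grouped by tile and ordered
    front to back within each tile, using one sort over packed keys.
    """
    keyed = []
    for idx, (cx, cy, radius, depth) in enumerate(splats):
        for tile in overlapped_tiles(cx, cy, radius, width, height):
            keyed.append((pack_key(tile, depth), idx))
    keyed.sort()
    return [(key >> 32, idx) for key, idx in keyed]


# A 32x16 image has two tiles; splat 1 straddles both.
splats = [(8, 8, 4, 3.0), (16, 8, 4, 1.0), (24, 8, 2, 2.0)]
print(sort_instances(splats, width=32, height=16))
# [(0, 1), (0, 0), (1, 1), (1, 2)]
```

Packing both criteria into one integer replaces a two-level (tile, then depth) sort with a single key sort, which maps well onto a vectorized sort primitive.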
This project recommends using uv for fast Python package management.

- Install uv:

  ```bash
  curl -LsSf https://astral.sh/uv/install.sh | sh
  ```

- Set up the environment (`.cpu_env`):

  ```bash
  uv venv .cpu_env --python 3.11
  source .cpu_env/bin/activate
  uv pip install -r requirements_cpu.txt
  ```
- Download the Fern dataset (from the NeRF LLFF data).
- Place it in `data/nerf_example_data/nerf_llff_data/fern`.
- The directory should contain `images_8/` and `sparse/0/`.
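Before starting a long run, it can help to confirm the scene folder has the layout described above. The sketch below uses a hypothetical helper (`missing_dataset_dirs` is not part of `jax_gs`); it only checks that the two subdirectories exist, not that `sparse/0/` actually holds a valid COLMAP reconstruction.

```python
from pathlib import Path

# Subdirectories this README says the Fern scene must contain.
REQUIRED_SUBDIRS = ("images_8", "sparse/0")


def missing_dataset_dirs(scene_dir) -> list[str]:
    """Return the required subdirectories absent under scene_dir (empty if OK)."""
    root = Path(scene_dir)
    return [sub for sub in REQUIRED_SUBDIRS if not (root / sub).is_dir()]


if __name__ == "__main__":
    scene = "data/nerf_example_data/nerf_llff_data/fern"
    missing = missing_dataset_dirs(scene)
    if missing:
        print(f"{scene} is missing: {', '.join(missing)}")
    else:
        print(f"{scene} looks complete")
```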
Start a new training session on the Fern dataset:

```bash
python train_fern.py --num_iterations 10000
```

Continue training from the latest `.ply` checkpoint:

```bash
python train_fern_resume.py --num_iterations 5000
```

Parameters:

- `--num_iterations`: Total iterations for `train_fern.py`, or additional iterations for `train_fern_resume.py`. Default is 10000.
Outputs:

- Progress images: `results/fern_YYYYMMDD_HHMMSS/progress/`
- PLY checkpoints: `results/fern_YYYYMMDD_HHMMSS/ply/`
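`train_fern_resume.py` continues from the latest `.ply` checkpoint, but how it selects that file is not documented here. The sketch below is one plausible approach, under two assumptions: run directories follow the `fern_YYYYMMDD_HHMMSS` pattern, so sorting their names as strings also sorts them chronologically, and within a run the most recently modified `.ply` file is the newest checkpoint. `latest_checkpoint` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def latest_checkpoint(results_dir="results", prefix="fern_"):
    """Return the newest .ply file from the most recent run directory, or None.

    Assumes run directories are named <prefix>YYYYMMDD_HHMMSS, so a plain
    string sort of their names is also a chronological sort.
    """
    root = Path(results_dir)
    if not root.is_dir():
        return None
    runs = sorted(p for p in root.glob(f"{prefix}*") if p.is_dir())
    # Walk runs from newest to oldest and return the first checkpoint found.
    for run in reversed(runs):
        ply_dir = run / "ply"
        if not ply_dir.is_dir():
            continue
        plys = sorted(ply_dir.glob("*.ply"), key=lambda p: p.stat().st_mtime)
        if plys:
            return plys[-1]
    return None
```

Falling back to older runs means an interrupted run that never wrote a checkpoint does not block resuming from the previous one.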
Visualize the results using the Viser-based viewers.

Visualize trained splats from a saved `.ply` checkpoint:

```bash
python viewer_ply.py results/fern_YYYYMMDD_HHMMSS/ply/fern_final.ply
```

Visualize randomly generated 3D Gaussians to understand the representation:

```bash
python viewer_random.py --num 5000
```

To verify mathematical correctness and IO stability, run the test suite with pytest:
```bash
# Recommended: run on CPU for deterministic numerical checks
JAX_PLATFORMS=cpu PYTHONPATH=. pytest tests/
```

If you encounter environment issues, you can explicitly point to your virtual environment's site-packages:

```bash
JAX_PLATFORMS=cpu PYTHONPATH=.:$(pwd)/.cpu_env/lib/python3.11/site-packages pytest tests/
```

Project structure:

- `jax_gs/`: Core package containing:
  - `core/`: `Gaussians` and `Camera` data structures.
  - `renderer/`: Tiled rasterization and projection kernels.
  - `io/`: COLMAP and PLY loading/saving logic.
  - `training/`: Loss functions and the JIT-compiled training step.
- `tests/`: Unit tests for each module.
- `train_fern.py`: Entry point for training.
- `train_fern_resume.py`: Entry point for resuming training.
- `viewer_ply.py`: Viser-based PLY visualization script.
- `viewer_random.py`: Viser-based random Gaussian visualization script.
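As a rough illustration of what each splat in a Gaussian data structure carries, and of the kind of property a numerical test can check, the sketch below uses the standard 3DGS parameterization: a mean, per-axis scales, a rotation quaternion, an opacity, and a color, with covariance `Σ = R S Sᵀ Rᵀ`. `Gaussian3D` and the helper functions are hypothetical and do not mirror the actual `jax_gs` classes. The sketch is plain Python rather than JAX so it runs anywhere, and it stores activated values, whereas 3DGS implementations commonly optimize log-scales and logit-opacities.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class Gaussian3D:
    """One splat (hypothetical fields, activated values)."""
    mean: tuple[float, float, float]         # world-space center
    scale: tuple[float, float, float]        # per-axis standard deviations
    quat: tuple[float, float, float, float]  # rotation (w, x, y, z), any nonzero norm
    opacity: float                           # in [0, 1]
    color: tuple[float, float, float]        # RGB; no spherical harmonics here


def quat_to_rotmat(q):
    """Rotation matrix of a quaternion (w, x, y, z), normalized first."""
    n = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / n for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def covariance(g):
    """3x3 covariance Sigma = R S S^T R^T with S = diag(scale)."""
    r = quat_to_rotmat(g.quat)
    m = [[r[i][j] * g.scale[j] for j in range(3)] for i in range(3)]  # M = R S
    return [[sum(m[i][k] * m[j][k] for k in range(3)) for j in range(3)]
            for i in range(3)]  # Sigma = M M^T


def random_gaussian(rng):
    """Random splat near the origin (illustrative ranges only)."""
    return Gaussian3D(
        mean=tuple(rng.uniform(-1.0, 1.0) for _ in range(3)),
        scale=tuple(rng.uniform(0.01, 0.1) for _ in range(3)),
        quat=tuple(rng.gauss(0.0, 1.0) for _ in range(4)),
        opacity=rng.random(),
        color=tuple(rng.random() for _ in range(3)),
    )


def test_covariance_properties():
    # Identity rotation: Sigma must equal diag(scale^2).
    g = Gaussian3D((0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (1.0, 0.0, 0.0, 0.0),
                   1.0, (1.0, 1.0, 1.0))
    cov = covariance(g)
    for i in range(3):
        for j in range(3):
            expected = g.scale[i] ** 2 if i == j else 0.0
            assert math.isclose(cov[i][j], expected, abs_tol=1e-12)
    # Arbitrary rotation: Sigma stays symmetric, and its trace equals
    # sum(scale^2) because rotations preserve the trace.
    rng = random.Random(0)
    for _ in range(100):
        g = random_gaussian(rng)
        cov = covariance(g)
        for i in range(3):
            for j in range(3):
                assert cov[i][j] == cov[j][i]
        trace = sum(cov[i][i] for i in range(3))
        assert math.isclose(trace, sum(s * s for s in g.scale), rel_tol=1e-9)
```

The `test_` function follows pytest naming, so pytest would collect it if the block were saved in a test module.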
