A batteries-included template for PyTorch machine learning projects using Lightning, wandb, and modern Python tooling.
- PyTorch Lightning: Structured training framework with minimal boilerplate
- Pydantic Configuration: Type-safe configuration management
- Weights & Biases: Integrated experiment tracking
- Modern Tooling: Built with `uv` for fast dependency management
- Code Quality: Pre-configured with `ruff`, `mypy`, `pytest`, and `pre-commit` hooks
- Git-based Versioning: Automatic experiment naming using git commit hashes
```
.
├── src/template/
│   ├── config.py               # Pydantic configuration classes
│   ├── lightning_module.py     # Base Lightning module
│   ├── datasets/               # Dataset implementations
│   ├── modeling/               # Model architectures
│   └── scripts/
│       └── train.py            # Training script
├── tests/                      # Test files
├── pyproject.toml              # Project metadata and dependencies
├── .pre-commit-config.yaml     # Pre-commit hooks configuration
└── .env.example                # Example environment variables
```
Install `uv` if you don't already have it:

```shell
curl -LsSf https://astral.sh/uv/install.sh | sh
```

When creating a new project from this template:
1. Clone or fork this repository
2. Rename the `src/template` directory to your project name: `mv src/template src/your_project_name`
3. Update `pyproject.toml`:
   - Change `name = "template"` to your project name
   - Update `module-name = ["template"]` to your project name
   - Update the `train` script path in `[project.scripts]`
4. Update import statements in Python files to use your new project name
Install the dependencies:

```shell
uv sync
```

Copy the example environment file and add your API keys:

```shell
cp .env.example .env
# Edit .env and add your wandb API key and other credentials
```

Install the pre-commit hooks:

```shell
uv run pre-commit install
```

Run the training script:
```shell
uv run train <data_root> --project my-project --num_devices 1
```

Available arguments:

- `data_root`: Path to your dataset (required)
- `--project`: Wandb project name (default: `"jigsaw-2025"`)
- `--num_devices`: Number of GPUs to use (default: 1)
- `--num_workers`: Number of data loading workers (default: 12)
- `--log_root`: Directory for logs and checkpoints (default: `"logs"`)
- `--checkpoint_path`: Resume training from a checkpoint
- `--weights_path`: Load model weights
- `--debug`: Enable debug mode
- `--fast_dev_run`: Run a quick test with minimal data
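The template's `train.py` wires these options up itself; purely as an illustration, the interface above corresponds to an `argparse` parser along these lines (names and defaults taken from the list, everything else a sketch, not the template's actual code):

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Illustrative parser mirroring the CLI described above."""
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("data_root", help="Path to your dataset")
    parser.add_argument("--project", default="jigsaw-2025", help="Wandb project name")
    parser.add_argument("--num_devices", type=int, default=1, help="Number of GPUs to use")
    parser.add_argument("--num_workers", type=int, default=12, help="Data loading workers")
    parser.add_argument("--log_root", default="logs", help="Directory for logs and checkpoints")
    parser.add_argument("--checkpoint_path", default=None, help="Resume from a checkpoint")
    parser.add_argument("--weights_path", default=None, help="Load model weights")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--fast_dev_run", action="store_true", help="Quick test with minimal data")
    return parser


args = build_parser().parse_args(["./data", "--project", "my-project"])
print(args.project)  # my-project
```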
Edit `src/template/config.py` to customize hyperparameters:

```python
from pydantic import BaseModel


class Config(BaseModel):
    # Reproducibility
    seed: int = 42

    # Data
    test_split: float = 0.1
    batch_size: int = 16

    # Training
    max_epochs: int = 200
    early_stopping_patience: int = 30
    learning_rate: float = 1e-4
    min_learning_rate: float = 1e-6
    weight_decay: float = 1e-2
```
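Because `Config` is a Pydantic model, overrides are validated and type-coerced when the model is constructed, rather than failing deep inside a training run. A minimal sketch of what that buys you (assuming Pydantic v2; a trimmed-down `Config` is redefined here so the snippet is self-contained):

```python
from pydantic import BaseModel, ValidationError


class Config(BaseModel):
    seed: int = 42
    batch_size: int = 16
    learning_rate: float = 1e-4


# Overrides are coerced to the declared type: "32" becomes the int 32.
config = Config(batch_size="32")
print(config.batch_size)  # 32

# Invalid values fail loudly at construction time.
try:
    Config(learning_rate="fast")
except ValidationError:
    print("rejected invalid learning_rate")
```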
1. Create your Lightning module by inheriting from `BaseLightningModule`:

   ```python
   from template.lightning_module import BaseLightningModule


   class MyModel(BaseLightningModule):
       def training_step(self, batch, batch_idx):
           # Your training logic here
           pass

       def validation_step(self, batch, batch_idx):
           # Your validation logic here
           pass
   ```

2. Add your dataset in `src/template/datasets/`:

   ```python
   from torch.utils.data import Dataset


   class MyDataset(Dataset):
       def __init__(self, data_root, config):
           # Initialize your dataset
           pass
   ```

3. Update the training script to use your model and dataset
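Before wiring a new dataset into Lightning, it can be exercised directly with a `DataLoader` as a sanity check. A self-contained sketch with made-up toy data (the dataset class and shapes here are illustrative, not part of the template):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Stand-in dataset: n random (feature, label) pairs."""

    def __init__(self, n: int = 100):
        self.x = torch.randn(n, 8)
        self.y = torch.randint(0, 2, (n,))

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]


# The default collate function stacks samples into batched tensors.
loader = DataLoader(ToyDataset(), batch_size=16, shuffle=True)
xb, yb = next(iter(loader))
print(tuple(xb.shape))  # (16, 8)
```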
Run the test suite:

```shell
uv run pytest
```

Type-check the code:

```shell
uv run mypy src/
```

Lint and format:

```shell
uv run ruff check src/
uv run ruff format src/
```

Pre-commit hooks will automatically run on every commit to ensure code quality. To run them manually:

```shell
uv run pre-commit run --all-files
```

Core dependencies:
- PyTorch: Deep learning framework (with GPU support)
- Lightning: High-level PyTorch wrapper
- Pydantic: Data validation and configuration
- Wandb: Experiment tracking
- python-dotenv: Environment variable management
Development tools:
- `ruff`: Fast Python linter and formatter
- `mypy`: Static type checker
- `pytest`: Testing framework
- `pre-commit`: Git hooks for code quality
This project uses `uv_build` as the build backend, which is significantly faster than traditional build backends like setuptools or hatchling.
To build the project:

```shell
uv build
```

See the LICENSE file for details.
- Create a new branch for your feature
- Make your changes
- Ensure all tests pass and pre-commit hooks succeed
- Submit a pull request
- Experiment names automatically include the git commit hash for reproducibility
- Use `.env` for sensitive information (API keys, credentials)
- The config system uses Pydantic for type safety and validation
- Lightning automatically handles distributed training, gradient accumulation, and mixed precision
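The git-hash naming convention from the first note can be sketched as follows (the function name `experiment_name` is made up for illustration; this is the idea, not the template's actual code):

```python
import subprocess


def experiment_name(prefix: str = "run") -> str:
    """Return a run name embedding the short hash of the current commit.

    Illustrative sketch of the naming scheme described above.
    """
    commit = subprocess.run(
        ["git", "rev-parse", "--short", "HEAD"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()
    return f"{prefix}-{commit}"
```

Run names like `run-1a2b3c4` let any wandb run be traced back to the exact commit that produced it.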