From 6c14902577cad8718d53476621b89aee3a05c54f Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Thu, 19 Mar 2026 08:56:55 -0400 Subject: [PATCH] add mixed precision training/matmul32 precision --- src/deepforest/conf/config.yaml | 8 ++++++++ src/deepforest/conf/schema.py | 2 ++ src/deepforest/main.py | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index e659aac5d..ca66f0828 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -17,6 +17,13 @@ model: name: 'weecology/deepforest-tree' revision: 'main' +# Trainer precision. Override to 16-mixed for faster training, 32-true for full precision. +precision: + +# On CUDA, setting the matmul precision can provide a speedup without affecting model performance. +# The default 'highest' uses full float32 matmuls; try 'high' to trade a little precision for speed. +matmul_precision: highest + # Specify a label_dict to override model settings. # By default, this will be populated from the model # checkpoint that is selected in model.name/revision. @@ -77,6 +84,7 @@ train: # preload images to GPU memory for fast training. This depends on GPU size and number of images. 
preload_images: False + validation: csv_file: root_dir: diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 9879538ee..91b195539 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -133,6 +133,8 @@ class Config: devices: int | str = "auto" accelerator: str = "auto" batch_size: int = 1 + precision: str | None = None + matmul_precision: str = "highest" architecture: str = "retinanet" num_classes: int | None = None diff --git a/src/deepforest/main.py b/src/deepforest/main.py index e22c9b4f9..a3269411d 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -231,6 +231,9 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): else: enable_checkpointing = False + if torch.cuda.is_available(): + torch.set_float32_matmul_precision(self.config.matmul_precision) + trainer_args = { "logger": logger, "max_epochs": self.config.train.epochs, @@ -243,6 +246,8 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): "num_sanity_val_steps": num_sanity_val_steps, "default_root_dir": self.config.log_root, } + if self.config.precision is not None: + trainer_args["precision"] = self.config.precision # Update with kwargs to allow them to override config trainer_args.update(kwargs)