From 7aa68a4c5fa7205cfc39ceed91f51df79ff5dd46 Mon Sep 17 00:00:00 2001 From: Tal <21198860+mrT23@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:35:06 +0300 Subject: [PATCH 1/6] Add knowledge distillation model and loss function support --- timm/utils/model_kd.py | 77 ++++++++++++++++++++++++++++++++++++++++++ train.py | 21 ++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 timm/utils/model_kd.py diff --git a/timm/utils/model_kd.py b/timm/utils/model_kd.py new file mode 100644 index 0000000000..45d50833ba --- /dev/null +++ b/timm/utils/model_kd.py @@ -0,0 +1,77 @@ +import logging +import torch +import torch.nn as nn +import torchvision.transforms as T +from timm.models import create_model + +_logger = logging.getLogger(__name__) + +class build_kd_model(nn.Module): + def __init__(self, args): + super(build_kd_model, self).__init__() + + _logger.info(f"Creating KD model: from '{args.kd_model_name}'") + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + model_kd = create_model( + model_name=args.kd_model_name, + num_classes=args.num_classes, + pretrained=True, + in_chans=in_chans) + + # compile model + model_kd.cpu().eval() + try: + model_kd = torch.compile(model_kd) + _logger.info(f"torch.compile applied successfully to KD model") + except Exception as e: + _logger.warning(f"torch.compile failed with error {e}, continuing KD model without torch compilation") + + self.model = model_kd.cuda() + self.mean_model_kd = model_kd.default_cfg['mean'] + self.std_model_kd = model_kd.default_cfg['std'] + + # handling different normalization of teacher and student + def normalize_input(self, input, student_model): + if hasattr(student_model, 'module'): + model_s = student_model.module + else: + model_s = student_model + + mean_student = model_s.default_cfg['mean'] + std_student = model_s.default_cfg['std'] + + input_kd = input + if mean_student != self.mean_model_kd or std_student != self.std_model_kd: + std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1], + self.std_model_kd[2] / std_student[2]) + transform_std = T.Normalize(mean=(0, 0, 0), std=std) + + mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1], + self.mean_model_kd[2] - mean_student[2]) + transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) + + input_kd = transform_mean(transform_std(input)) + + return input_kd + + +def add_kd_loss(_loss, output, input, model, model_kd, args): + # student probability calculation + prob_s = torch.nn.functional.log_softmax(output, dim=-1) + + # teacher probability calculation + with torch.no_grad(): + input_kd = model_kd.normalize_input(input, model) + out_t = model_kd.model(input_kd.detach()) + prob_t = torch.nn.functional.softmax(out_t, dim=-1) + + # adding KL loss + if not args.use_kd_only_loss: + _loss += args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + else: # only kd + _loss = args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + + return _loss + diff --git a/train.py b/train.py index 0dacdc1dc2..38cbb2f0b4 100755 --- a/train.py +++ b/train.py @@ -41,6 +41,7 @@ from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler +from timm.utils.model_kd import build_kd_model, add_kd_loss try: from apex import amp @@ -417,6 +418,14 @@ group.add_argument('--naflex-loss-scale', default='linear', type=str, help='Scale loss 
(gradient) by batch_size ("none", "sqrt", or "linear")') +# Knowledge Distillation parameters +parser.add_argument('--kd-model-name', default=None, type=str, + help='Name of teacher model for knowledge distillation') +parser.add_argument('--alpha-kd', default=5, type=float, + help='Weight for KD loss (default: 5)') +parser.add_argument('--use-kd-only-loss', action='store_true', default=False, + help='Use only KD loss, without cross-entropy loss') + def _parse_args(): # Do we have a config file to parse? @@ -482,6 +491,11 @@ def main(): utils.random_seed(args.seed, args.rank) + # Create the KD teacher model if specified + model_kd = None + if args.kd_model_name is not None: + model_kd = build_kd_model(args) + if args.fuser: utils.set_jit_fuser(args.fuser) if args.fast_norm: @@ -1008,6 +1022,7 @@ def main(): mixup_fn=mixup_fn, num_updates_total=num_epochs * updates_per_epoch, naflex_mode=naflex_mode, + model_kd=model_kd, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -1123,6 +1138,7 @@ def train_one_epoch( mixup_fn=None, num_updates_total=None, naflex_mode=False, + model_kd=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1169,6 +1185,11 @@ def _forward(): with amp_autocast(): output = model(input) _loss = loss_fn(output, target) + + # KD logic + if model_kd is not None: + _loss= add_kd_loss(_loss, output, input, model, model_kd, args) + if accum_steps > 1: _loss /= accum_steps return _loss From 665714264988a69624b75a063182cb2ec07a6577 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Oct 2025 12:45:37 -0700 Subject: [PATCH 2/6] Cleanup distillation code --- timm/kd/__init__.py | 4 ++ timm/kd/distillation.py | 142 ++++++++++++++++++++++++++++++++++++++++ timm/utils/model_kd.py | 77 ---------------------- train.py | 28 ++++++-- 4 files changed, 167 insertions(+), 84 deletions(-) create mode 100644 timm/kd/__init__.py create mode 100644 timm/kd/distillation.py delete mode 100644 timm/utils/model_kd.py diff --git a/timm/kd/__init__.py b/timm/kd/__init__.py new file mode 100644 index 0000000000..8b3d7f2cea --- /dev/null +++ b/timm/kd/__init__.py @@ -0,0 +1,4 @@ +"""Knowledge Distillation module for timm""" +from .distillation import DistillationTeacher, apply_kd_loss + +__all__ = ['DistillationTeacher', 'apply_kd_loss'] diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py new file mode 100644 index 0000000000..b9d051993d --- /dev/null +++ b/timm/kd/distillation.py @@ -0,0 +1,142 @@ +"""Knowledge Distillation helpers for training with a teacher model.""" +import logging +from typing import Tuple + +import torch +import torch.nn as nn +import torchvision.transforms as T + +from timm.models import create_model + + +_logger = logging.getLogger(__name__) + + +class DistillationTeacher(nn.Module): + """Wrapper for a teacher model used in knowledge distillation. + + Creates and manages a pre-trained teacher model for knowledge distillation, + handling model compilation and normalization differences between teacher and student. 
+ + Args: + model_name: Name of the teacher model to create + num_classes: Number of output classes + in_chans: Number of input channels + pretrained: Whether to load pretrained weights + device: Device to place the model on (default: 'cuda') + dtype: Model dtype (default: None, uses float32) + """ + + def __init__( + self, + model_name: str, + num_classes: int, + in_chans: int = 3, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = None, + ): + super().__init__() + + _logger.info(f"Creating KD teacher model: '{model_name}'") + + model_kd = create_model( + model_name=model_name, + num_classes=num_classes, + pretrained=True, + in_chans=in_chans, + ) + + model_kd = model_kd.to(device=device, dtype=dtype) + model_kd.eval() + + try: + model_kd = torch.compile(model_kd) + _logger.info("torch.compile applied successfully to KD teacher model") + except Exception as e: + _logger.warning(f"torch.compile failed with error {e}, continuing without compilation") + + self.model = model_kd + self.mean_model_kd = model_kd.pretrained_cfg['mean'] + self.std_model_kd = model_kd.pretrained_cfg['std'] + + def normalize_input( + self, + input: torch.Tensor, + student_model: nn.Module, + ) -> torch.Tensor: + """Normalize input to match teacher's expected normalization. + + Handles different normalization between teacher and student models by + converting the student's normalized input to the teacher's expected format. + + Args: + input: Input tensor (already normalized for student) + student_model: Student model to extract normalization params from + + Returns: + Input tensor normalized for the teacher model + """ + if hasattr(student_model, 'module'): + model_s = student_model.module + else: + model_s = student_model + + mean_student = model_s.pretrained_cfg['mean'] + std_student = model_s.pretrained_cfg['std'] + + input_kd = input + if mean_student != self.mean_model_kd or std_student != self.std_model_kd: + # Compute normalized std and mean transformations + std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student)) + transform_std = T.Normalize(mean=(0, 0, 0), std=std) + + mean = tuple(t_mean - s_mean for t_mean, s_mean in zip(self.mean_model_kd, mean_student)) + transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) + + input_kd = transform_mean(transform_std(input)) + + return input_kd + + +def apply_kd_loss( + loss: torch.Tensor, + student_output: torch.Tensor, + input: torch.Tensor, + student_model: nn.Module, + teacher_model: DistillationTeacher, + alpha_kd: float, + use_kd_only: bool = False, +) -> torch.Tensor: + """Apply knowledge distillation loss. + + Computes KL divergence between student and teacher outputs and combines + with the base loss (or replaces it if use_kd_only is True). 
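For orientation, a minimal usage sketch of the API introduced in this patch; the model names, batch shape, and alpha value below are hypothetical, and the batch is assumed to already be normalized for the student:

    import torch
    import torch.nn.functional as F
    from timm.models import create_model
    from timm.kd import DistillationTeacher, apply_kd_loss

    student = create_model('resnet50', num_classes=1000).cuda()
    teacher = DistillationTeacher('convnext_base', num_classes=1000, device=torch.device('cuda'))

    input = torch.randn(8, 3, 224, 224, device='cuda')   # stands in for a student-normalized batch
    target = torch.randint(0, 1000, (8,), device='cuda')
    output = student(input)
    loss = F.cross_entropy(output, target)
    loss = apply_kd_loss(loss, output, input, student, teacher, alpha_kd=5.0, use_kd_only=False)
    loss.backward()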
+ + Args: + loss: Base loss (e.g., cross-entropy with labels) + student_output: Logits from student model + input: Input tensor (already normalized for student) + student_model: Student model being trained + teacher_model: Teacher model for distillation + alpha_kd: Weight for the KD loss component + use_kd_only: If True, only use KD loss (ignore base loss) + + Returns: + Combined loss with KD component + """ + # Student probability calculation + prob_s = torch.nn.functional.log_softmax(student_output, dim=-1) + + # Teacher probability calculation + with torch.inference_mode(): + input_kd = teacher_model.normalize_input(input, student_model) + out_t = teacher_model.model(input_kd.detach()) + prob_t = torch.nn.functional.softmax(out_t, dim=-1) + + # Compute KL divergence loss + kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') + + if use_kd_only: + return kd_loss + else: + return loss + kd_loss diff --git a/timm/utils/model_kd.py b/timm/utils/model_kd.py deleted file mode 100644 index 45d50833ba..0000000000 --- a/timm/utils/model_kd.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import torch -import torch.nn as nn -import torchvision.transforms as T -from timm.models import create_model - -_logger = logging.getLogger(__name__) - -class build_kd_model(nn.Module): - def __init__(self, args): - super(build_kd_model, self).__init__() - - _logger.info(f"Creating KD model: from '{args.kd_model_name}'") - in_chans = 3 - if args.in_chans is not None: - in_chans = args.in_chans - model_kd = create_model( - model_name=args.kd_model_name, - num_classes=args.num_classes, - pretrained=True, - in_chans=in_chans) - - # compile model - model_kd.cpu().eval() - try: - model_kd = torch.compile(model_kd) - _logger.info(f"torch.compile applied successfully to KD model") - except Exception as e: - _logger.warning(f"torch.compile failed with error {e}, continuing KD model without torch compilation") - - self.model = model_kd.cuda() - self.mean_model_kd = model_kd.default_cfg['mean'] - self.std_model_kd = model_kd.default_cfg['std'] - - # handling different normalization of teacher and student - def normalize_input(self, input, student_model): - if hasattr(student_model, 'module'): - model_s = student_model.module - else: - model_s = student_model - - mean_student = model_s.default_cfg['mean'] - std_student = model_s.default_cfg['std'] - - input_kd = input - if mean_student != self.mean_model_kd or std_student != self.std_model_kd: - std = (self.std_model_kd[0] / std_student[0], self.std_model_kd[1] / std_student[1], - self.std_model_kd[2] / std_student[2]) - transform_std = T.Normalize(mean=(0, 0, 0), std=std) - - mean = (self.mean_model_kd[0] - mean_student[0], self.mean_model_kd[1] - mean_student[1], - self.mean_model_kd[2] - mean_student[2]) - transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) - - input_kd = transform_mean(transform_std(input)) - - return input_kd - - -def add_kd_loss(_loss, output, input, model, model_kd, args): - # student probability calculation - prob_s = torch.nn.functional.log_softmax(output, dim=-1) - - # teacher probability calculation - with torch.no_grad(): - input_kd = model_kd.normalize_input(input, model) - out_t = model_kd.model(input_kd.detach()) - prob_t = torch.nn.functional.softmax(out_t, dim=-1) - - # adding KL loss - if not args.use_kd_only_loss: - _loss += args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') - else: # only kd - _loss = args.alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, 
reduction='batchmean') - - return _loss - diff --git a/train.py b/train.py index 38cbb2f0b4..440aa01487 100755 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler -from timm.utils.model_kd import build_kd_model, add_kd_loss +from timm.kd import DistillationTeacher, apply_kd_loss try: from apex import amp @@ -491,11 +491,6 @@ def main(): utils.random_seed(args.seed, args.rank) - # Create the KD teacher model if specified - model_kd = None - if args.kd_model_name is not None: - model_kd = build_kd_model(args) - if args.fuser: utils.set_jit_fuser(args.fuser) if args.fast_norm: @@ -545,6 +540,17 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) + # Create the KD teacher model if specified + model_kd = None + if args.kd_model_name is not None: + model_kd = DistillationTeacher( + model_name=args.kd_model_name, + num_classes=args.num_classes, + in_chans=in_chans, + device=device, + dtype=model_dtype, + ) + if utils.is_primary(args): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') @@ -1188,7 +1194,15 @@ def _forward(): # KD logic if model_kd is not None: - _loss= add_kd_loss(_loss, output, input, model, model_kd, args) + _loss = apply_kd_loss( + loss=_loss, + student_output=output, + input=input, + student_model=model, + teacher_model=model_kd, + alpha_kd=args.alpha_kd, + use_kd_only=args.use_kd_only_loss, + ) if accum_steps > 1: _loss /= accum_steps From 743c3757586dba4aef6a6649ecf3b52194ee080d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Oct 2025 12:54:08 -0700 Subject: [PATCH 3/6] Keep as no_grad --- timm/kd/distillation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py index b9d051993d..109d8703c2 100644 --- a/timm/kd/distillation.py +++ b/timm/kd/distillation.py @@ -128,7 +128,7 @@ def apply_kd_loss( prob_s = torch.nn.functional.log_softmax(student_output, dim=-1) # Teacher probability calculation - with torch.inference_mode(): + with torch.no_grad(): input_kd = teacher_model.normalize_input(input, student_model) out_t = teacher_model.model(input_kd.detach()) prob_t = torch.nn.functional.softmax(out_t, dim=-1) From 080b55b3627a1cdb00986273a504e64ad01a836d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Oct 2025 16:38:10 -0700 Subject: [PATCH 4/6] Add pretrained_path arg for kd --- timm/kd/distillation.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py index 109d8703c2..48d708fd5d 100644 --- a/timm/kd/distillation.py +++ b/timm/kd/distillation.py @@ -1,6 +1,6 @@ """Knowledge Distillation helpers for training with a teacher model.""" import logging -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -22,7 +22,6 @@ class DistillationTeacher(nn.Module): model_name: Name of the teacher model to create num_classes: Number of output classes in_chans: Number of input channels - pretrained: Whether to load pretrained weights device: Device to place the model on (default: 'cuda') dtype: Model dtype (default: None, uses float32) """ @@ -32,6 +31,7 @@ def __init__( model_name: str, num_classes: int, in_chans: int = 3, + pretrained_path: Optional[str] = None, device: torch.device = torch.device('cuda'), dtype: 
torch.dtype = None, ): @@ -39,11 +39,19 @@ def __init__( _logger.info(f"Creating KD teacher model: '{model_name}'") + pretrained_kwargs = {'pretrained': True} + if pretrained_path: + # specify a local checkpoint path to load pretrained weights from + pretrained_kwargs['pretrained_cfg_overlay'] = dict( + file=pretrained_path, + num_classes=num_classes, # needed to avoid head adaptation? + ) + model_kd = create_model( model_name=model_name, num_classes=num_classes, - pretrained=True, in_chans=in_chans, + **pretrained_kwargs, ) model_kd = model_kd.to(device=device, dtype=dtype) From baa1eabd218e1e1b1c9ad143f040154954066f32 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 27 Nov 2025 08:47:57 -0800 Subject: [PATCH 5/6] Revamp distillation by using a lightweight task arch. Encapsulates extra projections, etc that may be needed. --- timm/kd/__init__.py | 4 - timm/kd/distillation.py | 150 ---------- timm/task/__init__.py | 17 ++ timm/task/classification.py | 90 ++++++ timm/task/distillation.py | 574 ++++++++++++++++++++++++++++++++++++ timm/task/task.py | 100 +++++++ train.py | 161 ++++++---- 7 files changed, 888 insertions(+), 208 deletions(-) delete mode 100644 timm/kd/__init__.py delete mode 100644 timm/kd/distillation.py create mode 100644 timm/task/__init__.py create mode 100644 timm/task/classification.py create mode 100644 timm/task/distillation.py create mode 100644 timm/task/task.py diff --git a/timm/kd/__init__.py b/timm/kd/__init__.py deleted file mode 100644 index 8b3d7f2cea..0000000000 --- a/timm/kd/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Knowledge Distillation module for timm""" -from .distillation import DistillationTeacher, apply_kd_loss - -__all__ = ['DistillationTeacher', 'apply_kd_loss'] diff --git a/timm/kd/distillation.py b/timm/kd/distillation.py deleted file mode 100644 index 48d708fd5d..0000000000 --- a/timm/kd/distillation.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Knowledge Distillation helpers for training with a teacher model.""" -import logging -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torchvision.transforms as T - -from timm.models import create_model - - -_logger = logging.getLogger(__name__) - - -class DistillationTeacher(nn.Module): - """Wrapper for a teacher model used in knowledge distillation. - - Creates and manages a pre-trained teacher model for knowledge distillation, - handling model compilation and normalization differences between teacher and student. - - Args: - model_name: Name of the teacher model to create - num_classes: Number of output classes - in_chans: Number of input channels - device: Device to place the model on (default: 'cuda') - dtype: Model dtype (default: None, uses float32) - """ - - def __init__( - self, - model_name: str, - num_classes: int, - in_chans: int = 3, - pretrained_path: Optional[str] = None, - device: torch.device = torch.device('cuda'), - dtype: torch.dtype = None, - ): - super().__init__() - - _logger.info(f"Creating KD teacher model: '{model_name}'") - - pretrained_kwargs = {'pretrained': True} - if pretrained_path: - # specify a local checkpoint path to load pretrained weights from - pretrained_kwargs['pretrained_cfg_overlay'] = dict( - file=pretrained_path, - num_classes=num_classes, # needed to avoid head adaptation? 
- ) - - model_kd = create_model( - model_name=model_name, - num_classes=num_classes, - in_chans=in_chans, - **pretrained_kwargs, - ) - - model_kd = model_kd.to(device=device, dtype=dtype) - model_kd.eval() - - try: - model_kd = torch.compile(model_kd) - _logger.info("torch.compile applied successfully to KD teacher model") - except Exception as e: - _logger.warning(f"torch.compile failed with error {e}, continuing without compilation") - - self.model = model_kd - self.mean_model_kd = model_kd.pretrained_cfg['mean'] - self.std_model_kd = model_kd.pretrained_cfg['std'] - - def normalize_input( - self, - input: torch.Tensor, - student_model: nn.Module, - ) -> torch.Tensor: - """Normalize input to match teacher's expected normalization. - - Handles different normalization between teacher and student models by - converting the student's normalized input to the teacher's expected format. - - Args: - input: Input tensor (already normalized for student) - student_model: Student model to extract normalization params from - - Returns: - Input tensor normalized for the teacher model - """ - if hasattr(student_model, 'module'): - model_s = student_model.module - else: - model_s = student_model - - mean_student = model_s.pretrained_cfg['mean'] - std_student = model_s.pretrained_cfg['std'] - - input_kd = input - if mean_student != self.mean_model_kd or std_student != self.std_model_kd: - # Compute normalized std and mean transformations - std = tuple(t_std / s_std for t_std, s_std in zip(self.std_model_kd, std_student)) - transform_std = T.Normalize(mean=(0, 0, 0), std=std) - - mean = tuple(t_mean - s_mean for t_mean, s_mean in zip(self.mean_model_kd, mean_student)) - transform_mean = T.Normalize(mean=mean, std=(1, 1, 1)) - - input_kd = transform_mean(transform_std(input)) - - return input_kd - - -def apply_kd_loss( - loss: torch.Tensor, - student_output: torch.Tensor, - input: torch.Tensor, - student_model: nn.Module, - teacher_model: DistillationTeacher, - alpha_kd: float, - use_kd_only: bool = False, -) -> torch.Tensor: - """Apply knowledge distillation loss. - - Computes KL divergence between student and teacher outputs and combines - with the base loss (or replaces it if use_kd_only is True). - - Args: - loss: Base loss (e.g., cross-entropy with labels) - student_output: Logits from student model - input: Input tensor (already normalized for student) - student_model: Student model being trained - teacher_model: Teacher model for distillation - alpha_kd: Weight for the KD loss component - use_kd_only: If True, only use KD loss (ignore base loss) - - Returns: - Combined loss with KD component - """ - # Student probability calculation - prob_s = torch.nn.functional.log_softmax(student_output, dim=-1) - - # Teacher probability calculation - with torch.no_grad(): - input_kd = teacher_model.normalize_input(input, student_model) - out_t = teacher_model.model(input_kd.detach()) - prob_t = torch.nn.functional.softmax(out_t, dim=-1) - - # Compute KL divergence loss - kd_loss = alpha_kd * torch.nn.functional.kl_div(prob_s, prob_t, reduction='batchmean') - - if use_kd_only: - return kd_loss - else: - return loss + kd_loss diff --git a/timm/task/__init__.py b/timm/task/__init__.py new file mode 100644 index 0000000000..625488fd25 --- /dev/null +++ b/timm/task/__init__.py @@ -0,0 +1,17 @@ +"""Training task abstractions for timm. 
+ +This module provides task-based abstractions for training loops where each task +encapsulates both the forward pass and loss computation, returning a dictionary +with loss components and outputs for logging. +""" +from .task import TrainingTask +from .classification import ClassificationTask +from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask + +__all__ = [ + 'TrainingTask', + 'ClassificationTask', + 'DistillationTeacher', + 'LogitDistillationTask', + 'FeatureDistillationTask', +] diff --git a/timm/task/classification.py b/timm/task/classification.py new file mode 100644 index 0000000000..2f81871b3a --- /dev/null +++ b/timm/task/classification.py @@ -0,0 +1,90 @@ +"""Classification training task.""" +import logging +from typing import Callable, Dict, Optional, Union + +import torch +import torch.nn as nn + +from .task import TrainingTask + +_logger = logging.getLogger(__name__) + + +class ClassificationTask(TrainingTask): + """Standard supervised classification task. + + Simple task that performs a forward pass through the model and computes + the classification loss. + + Args: + model: The model to train + criterion: Loss function (e.g., CrossEntropyLoss) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda')) + >>> result = task(input, target) + >>> result['loss'].backward() + """ + + def __init__( + self, + model: nn.Module, + criterion: Union[nn.Module, Callable], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.model = model + self.criterion = criterion + + if self.verbose: + loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__ + _logger.info(f"ClassificationTask: criterion={loss_name}") + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'ClassificationTask': + """Prepare task for distributed training. + + Wraps the model in DistributedDataParallel (DDP). + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass through model and compute classification loss. 
+ + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Classification loss + - 'output': Model logits + """ + output = self.model(input) + loss = self.criterion(output, target) + + return { + 'loss': loss, + 'output': output, + } diff --git a/timm/task/distillation.py b/timm/task/distillation.py new file mode 100644 index 0000000000..d7f83c5330 --- /dev/null +++ b/timm/task/distillation.py @@ -0,0 +1,574 @@ +"""Knowledge distillation training tasks and components.""" +import logging +from typing import Dict, Optional, Literal, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.models import create_model +from timm.utils import unwrap_model + +from .task import TrainingTask + + +_logger = logging.getLogger(__name__) + + +class DistillationTeacher(nn.Module): + """Wrapper for a teacher model used in knowledge distillation. + + Creates and manages a pre-trained teacher model for knowledge distillation, + handling model compilation and normalization differences between teacher and student. + + Args: + model_name: Name of the teacher model to create + num_classes: Number of output classes + in_chans: Number of input channels + pretrained_path: Optional path to pretrained weights + device: Device to place the model on + dtype: Model dtype (uses float32 if None) + """ + + def __init__( + self, + model_name: str, + num_classes: int, + in_chans: int = 3, + pretrained_path: Optional[str] = None, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = None, + ): + super().__init__() + + _logger.info(f"Creating KD teacher model: '{model_name}'") + + pretrained_kwargs = {'pretrained': True} + if pretrained_path: + # specify a local checkpoint path to load pretrained weights from + pretrained_kwargs['pretrained_cfg_overlay'] = dict( + file=pretrained_path, + num_classes=num_classes, + ) + + model_kd = create_model( + model_name=model_name, + num_classes=num_classes, + in_chans=in_chans, + device=device, + dtype=dtype, + **pretrained_kwargs, + ) + + model_kd.eval() + self.model = model_kd + + # Register normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + mean_kd = torch.tensor(model_kd.pretrained_cfg['mean'], device=device, dtype=dtype).view(1, -1, 1, 1) + std_kd = torch.tensor(model_kd.pretrained_cfg['std'], device=device, dtype=dtype).view(1, -1, 1, 1) + self.register_buffer('mean_kd', mean_kd, persistent=False) + self.register_buffer('std_kd', std_kd, persistent=False) + + def forward( + self, + input: torch.Tensor, + return_features: bool = False, + ) -> torch.Tensor: + """Forward pass through teacher model. + + Args: + input: Input tensor (should already be normalized for teacher) + return_features: Whether to return pooled pre-logits features instead of logits + + Returns: + Logits or pooled pre-logits features depending on return_features flag + """ + if return_features: + if not hasattr(self.model, 'forward_features') or not hasattr(self.model, 'forward_head'): + raise ValueError( + f"Model {self.model.__class__.__name__} does not support feature extraction. " + "Ensure the model has 'forward_features' and 'forward_head' methods." 
+ ) + # Extract spatial features and pool to pre-logits + feature_map = self.model.forward_features(input) + return self.model.forward_head(feature_map, pre_logits=True) + else: + return self.model(input) + + def normalize_input( + self, + input: torch.Tensor, + student_mean: Optional[torch.Tensor] = None, + student_std: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Normalize input to match teacher's expected normalization. + + Handles different normalization between teacher and student models by + converting the student's normalized input to the teacher's expected format. + + Args: + input: Input tensor (already normalized for student) + student_mean: Student normalization mean buffer [1, 3, 1, 1] (None if same as teacher) + student_std: Student normalization std buffer [1, 3, 1, 1] (None if same as teacher) + + Returns: + Input tensor normalized for the teacher model + """ + # If no student normalization provided, assume it matches teacher (no conversion needed) + if student_mean is None or student_std is None: + return input + + # Check if renormalization is actually needed + if torch.equal(student_mean, self.mean_kd) and torch.equal(student_std, self.std_kd): + return input + + # De-normalize (Student) -> Re-normalize (Teacher) + # Combined for efficiency: (input * std_s + mean_s - mean_t) / std_t + return (input * student_std + student_mean - self.mean_kd) / self.std_kd + + +class LogitDistillationTask(TrainingTask): + """Logit-based knowledge distillation task. + + Performs distillation by matching student and teacher output logits using + KL divergence with temperature scaling. + + Loss weighting supports two modes: + 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss + 2. Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss + (used when only task_loss_weight is specified) + + Args: + student_model: Student model to train + teacher: Pre-configured teacher model wrapper + criterion: Task loss function (e.g., CrossEntropyLoss) + loss_type: Type of distillation loss (currently only 'kl' supported, reserved for future extensions) + distill_loss_weight: Weight for distillation loss + task_loss_weight: Weight for task loss + temperature: Softmax temperature for distillation (typical values: 1-4) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> # Independent weights + >>> task = LogitDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... distill_loss_weight=1.0, task_loss_weight=1.0, temperature=4.0, + ... device=torch.device('cuda'), + ... ) + >>> # Complementary mode (task_weight=0.3 means distill gets 0.7) + >>> task = LogitDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... task_loss_weight=0.3, temperature=4.0, + ... device=torch.device('cuda'), + ... 
) + """ + + def __init__( + self, + student_model: nn.Module, + teacher: DistillationTeacher, + criterion: nn.Module, + loss_type: str = 'kl', + distill_loss_weight: Optional[float] = None, + task_loss_weight: Optional[float] = None, + temperature: float = 1.0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.student = student_model + self.teacher = teacher + self.criterion = criterion + self.loss_type = loss_type + self.temperature = temperature + + if loss_type != 'kl': + raise ValueError(f"Unsupported loss_type '{loss_type}'. Currently only 'kl' is supported.") + + # Register student normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + student_unwrapped = unwrap_model(student_model) + student_mean = torch.tensor( + student_unwrapped.pretrained_cfg['mean'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + student_std = torch.tensor( + student_unwrapped.pretrained_cfg['std'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + self.register_buffer('student_mean', student_mean, persistent=False) + self.register_buffer('student_std', student_std, persistent=False) + + # Determine weighting mode + if distill_loss_weight is not None: + # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set) + self.distill_loss_weight = distill_loss_weight + self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0 + if self.verbose: + _logger.info( + f"LogitDistillationTask: Independent weights - " + f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}" + ) + elif task_loss_weight is not None: + # Mode 2: Only task_weight specified - complementary mode + self.task_loss_weight = task_loss_weight + self.distill_loss_weight = 1.0 - task_loss_weight + if self.verbose: + _logger.info( + f"LogitDistillationTask: Complementary mode - " + f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + else: + # Neither specified - use defaults (equal weighting) + self.distill_loss_weight = 1.0 + self.task_loss_weight = 1.0 + if self.verbose: + _logger.info( + f"LogitDistillationTask: Default equal weights - " + f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + + if self.verbose: + _logger.info( + f"LogitDistillationTask: loss_type={loss_type}, temperature={temperature}" + ) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'LogitDistillationTask': + """Prepare task for distributed training. + + Wraps the student model in DistributedDataParallel (DDP) while leaving + the frozen teacher model unwrapped. + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + # Ensure teacher parameters are frozen + for param in self.teacher.parameters(): + param.requires_grad = False + + # Wrap only student in DDP + self.student = DDP(self.student, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass with logit distillation. 
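For reference, the distillation term computed in this forward pass is the standard temperature-scaled KL divergence; a standalone sketch with hypothetical batch size, class count, and temperature:

    import torch
    import torch.nn.functional as F

    T = 4.0
    student_logits = torch.randn(8, 1000)
    teacher_logits = torch.randn(8, 1000)
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    log_p_t = F.log_softmax(teacher_logits / T, dim=-1)
    # the T**2 factor keeps gradient magnitudes roughly comparable across temperatures
    kd_loss = F.kl_div(log_p_s, log_p_t, reduction='batchmean', log_target=True) * T ** 2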
+ + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Combined training loss (task + distillation) + - 'output': Student logits + - 'task_loss': Classification loss component + - 'kd_loss': Distillation loss component + """ + # Student forward pass + student_logits = self.student(input) + + # Compute task loss + task_loss = self.criterion(student_logits, target) + + # Teacher forward pass (no gradient) + with torch.no_grad(): + input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std) + teacher_logits = self.teacher(input_kd.detach(), return_features=False) + + # Compute distillation loss (KL divergence with temperature scaling) + prob_s = F.log_softmax(student_logits / self.temperature, dim=-1) + prob_t = F.log_softmax(teacher_logits / self.temperature, dim=-1) + kd_loss = F.kl_div(prob_s, prob_t, reduction='batchmean', log_target=True) * (self.temperature ** 2) + + # Combine losses with weights + total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss + + return { + 'loss': total_loss, + 'output': student_logits, + 'task_loss': task_loss, + 'kd_loss': kd_loss, + } + + +class FeatureDistillationTrainableModule(nn.Module): + """Trainable module for feature distillation. + + Wraps student model and projection layer into a single module where all + trainable forward operations happen inside forward(). This ensures proper + DDP wrapping when the module is used with DistributedDataParallel. + + Args: + student_model: Student model to train + projection: Optional projection layer (Linear layer or None) + + Returns: + Tuple of (logits, projected_features) + """ + + def __init__( + self, + student_model: nn.Module, + projection: Optional[nn.Module] = None, + ): + super().__init__() + self.student = student_model + self.projection = projection + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass through student and projection. + + Args: + input: Input tensor [B, C, H, W] + + Returns: + Tuple of (student_logits, projected_features) + """ + # Extract features and compute logits + feature_map = self.student.forward_features(input) + student_logits = self.student.forward_head(feature_map) + student_features = self.student.forward_head(feature_map, pre_logits=True) + + # Apply projection if needed + if self.projection is not None: + student_features = self.projection(student_features) + + return student_logits, student_features + + +class FeatureDistillationTask(TrainingTask): + """Feature-based knowledge distillation task. + + Performs distillation by matching student and teacher intermediate features + (pooled pre-logits) using MSE loss. Automatically creates a projection layer + if student and teacher feature dimensions differ. + + Loss weighting supports two modes: + 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss + 2. 
Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss + (used when only task_loss_weight is specified) + + Args: + student_model: Student model to train + teacher: Pre-configured teacher model wrapper + criterion: Task loss function (e.g., CrossEntropyLoss) + distill_loss_weight: Weight for distillation loss + task_loss_weight: Weight for task loss + student_feature_dim: Student pre-logits dimension (auto-detected if None) + teacher_feature_dim: Teacher pre-logits dimension (auto-detected if None) + device: Device for task tensors/buffers + dtype: Dtype for task tensors/buffers + verbose: Enable info logging + + Example: + >>> # Independent weights + >>> task = FeatureDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... distill_loss_weight=5.0, task_loss_weight=1.0, + ... device=torch.device('cuda'), + ... ) + >>> # Complementary mode + >>> task = FeatureDistillationTask( + ... student_model=model, teacher=teacher, criterion=nn.CrossEntropyLoss(), + ... task_loss_weight=0.3, + ... device=torch.device('cuda'), + ... ) + """ + + def __init__( + self, + student_model: nn.Module, + teacher: DistillationTeacher, + criterion: nn.Module, + distill_loss_weight: Optional[float] = None, + task_loss_weight: Optional[float] = None, + student_feature_dim: Optional[int] = None, + teacher_feature_dim: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__(device=device, dtype=dtype, verbose=verbose) + self.teacher = teacher + self.criterion = criterion + + # Determine weighting mode + if distill_loss_weight is not None: + # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set) + self.distill_loss_weight = distill_loss_weight + self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0 + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Independent weights - " + f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}" + ) + elif task_loss_weight is not None: + # Mode 2: Only task_weight specified - complementary mode + self.task_loss_weight = task_loss_weight + self.distill_loss_weight = 1.0 - task_loss_weight + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Complementary mode - " + f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + else: + # Neither specified - use defaults (equal weighting) + self.distill_loss_weight = 1.0 + self.task_loss_weight = 1.0 + if self.verbose: + _logger.info( + f"FeatureDistillationTask: Default equal weights - " + f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}" + ) + + # Auto-detect feature dimensions if not provided + if student_feature_dim is None: + student_feature_dim = self._detect_feature_dim(student_model) + if teacher_feature_dim is None: + teacher_feature_dim = self._detect_feature_dim(teacher.model) + + # Create projection layer if dimensions differ + projection = None + if student_feature_dim != teacher_feature_dim: + if self.verbose: + _logger.info( + f"Creating projection layer: {student_feature_dim} -> {teacher_feature_dim}" + ) + projection = nn.Linear(student_feature_dim, teacher_feature_dim, device=self.device, dtype=self.dtype) + else: + if self.verbose: + _logger.info("Feature dimensions match, no projection needed") + + # Create trainable module wrapping student and projection + self.trainable_module = 
FeatureDistillationTrainableModule(student_model, projection) + + # Register student normalization values as non-persistent buffers + # Shape: [1, 3, 1, 1] for proper broadcasting over BCHW images + student_unwrapped = unwrap_model(student_model) + student_mean = torch.tensor( + student_unwrapped.pretrained_cfg['mean'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + student_std = torch.tensor( + student_unwrapped.pretrained_cfg['std'], + device=self.device, + dtype=self.dtype, + ).view(1, -1, 1, 1) + self.register_buffer('student_mean', student_mean, persistent=False) + self.register_buffer('student_std', student_std, persistent=False) + + if self.verbose: + _logger.info( + f"FeatureDistillationTask: " + f"student_dim={student_feature_dim}, teacher_dim={teacher_feature_dim}" + ) + + @staticmethod + def _detect_feature_dim(model: nn.Module) -> int: + """Auto-detect feature dimension from model. + + Tries head_hidden_size first (pre-logits dimension), then num_features. + """ + # Unwrap DDP/EMA wrapper if present + model = unwrap_model(model) + + if hasattr(model, 'head_hidden_size'): + return model.head_hidden_size + elif hasattr(model, 'num_features'): + return model.num_features + else: + raise ValueError( + "Cannot auto-detect feature dimension. Model must have " + "'head_hidden_size' or 'num_features' attribute, or you must " + "specify student_feature_dim and teacher_feature_dim explicitly." + ) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'FeatureDistillationTask': + """Prepare task for distributed training. + + Wraps the trainable module (student + projection) in DistributedDataParallel (DDP) + while leaving the frozen teacher model unwrapped. + + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + """ + from torch.nn.parallel import DistributedDataParallel as DDP + + # Ensure teacher parameters are frozen + for param in self.teacher.parameters(): + param.requires_grad = False + + # Wrap trainable module (student + projection) in DDP + self.trainable_module = DDP(self.trainable_module, device_ids=device_ids, **ddp_kwargs) + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward pass with feature distillation. 
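To make the feature-matching path concrete, a standalone sketch of what the trainable module and MSE loss compute when student and teacher widths differ; the 768/1024 dimensions and batch size are hypothetical:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    projection = nn.Linear(768, 1024)          # maps student pre-logits to the teacher's width
    student_features = torch.randn(8, 768)     # student.forward_head(feature_map, pre_logits=True)
    teacher_features = torch.randn(8, 1024)    # teacher.forward_head(feature_map, pre_logits=True)
    kd_loss = F.mse_loss(projection(student_features), teacher_features)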
+ + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary containing: + - 'loss': Combined training loss (task + distillation) + - 'output': Student logits + - 'task_loss': Classification loss component + - 'kd_loss': Feature distillation loss component + """ + # Student forward pass through trainable module (student + projection) + student_logits, student_features = self.trainable_module(input) + + # Compute task loss + task_loss = self.criterion(student_logits, target) + + # Teacher forward pass (no gradient) + with torch.no_grad(): + input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std) + teacher_features = self.teacher(input_kd.detach(), return_features=True) + + # Compute feature distillation loss (MSE) + kd_loss = F.mse_loss(student_features, teacher_features) + + # Combine losses with weights + total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss + + return { + 'loss': total_loss, + 'output': student_logits, + 'task_loss': task_loss, + 'kd_loss': kd_loss, + } diff --git a/timm/task/task.py b/timm/task/task.py new file mode 100644 index 0000000000..719c58a600 --- /dev/null +++ b/timm/task/task.py @@ -0,0 +1,100 @@ +"""Base training task abstraction. + +This module provides the base TrainingTask class that encapsulates a complete +forward pass including loss computation. Tasks return a dictionary with loss +components and outputs for logging. +""" +from typing import Dict, Optional + +import torch +import torch.nn as nn + + +class TrainingTask(nn.Module): + """Base class for training tasks. + + A training task encapsulates a complete forward pass including loss computation. + Tasks return a dictionary containing the training loss and other components for logging. + + The returned dictionary must contain: + - 'loss': The training loss for backward pass (required) + - 'output': Model output/logits for metric computation (recommended) + - Other task-specific loss components for logging (optional) + + Args: + device: Device for task tensors/buffers (defaults to cpu) + dtype: Dtype for task tensors/buffers (defaults to torch default) + verbose: Enable info logging + + Example: + >>> task = SomeTask(model, criterion, device=torch.device('cuda')) + >>> + >>> # Prepare for distributed training (if needed) + >>> if distributed: + >>> task.prepare_distributed(device_ids=[local_rank]) + >>> + >>> # Training loop + >>> result = task(input, target) + >>> result['loss'].backward() + """ + + def __init__( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + verbose: bool = True, + ): + super().__init__() + self.device = device if device is not None else torch.device('cpu') + self.dtype = dtype if dtype is not None else torch.get_default_dtype() + self.verbose = verbose + + def to(self, *args, **kwargs): + """Move task to device/dtype, keeping self.device and self.dtype in sync.""" + dummy = torch.empty(0).to(*args, **kwargs) + self.device = dummy.device + self.dtype = dummy.dtype + return super().to(*args, **kwargs) + + def prepare_distributed( + self, + device_ids: Optional[list] = None, + **ddp_kwargs + ) -> 'TrainingTask': + """Prepare task for distributed training. + + This method wraps trainable components in DistributedDataParallel (DDP) + while leaving non-trainable components (like frozen teacher models) unwrapped. + + Should be called after task initialization but before training loop. 
+ + Args: + device_ids: List of device IDs for DDP (e.g., [local_rank]) + **ddp_kwargs: Additional arguments passed to DistributedDataParallel + + Returns: + self (for method chaining) + + Example: + >>> task = LogitDistillationTask(student, teacher, criterion) + >>> task.prepare_distributed(device_ids=[args.local_rank]) + >>> task = torch.compile(task) # Compile after DDP + """ + # Default implementation - subclasses override if they need DDP + return self + + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Perform forward pass and compute loss. + + Args: + input: Input tensor [B, C, H, W] + target: Target labels [B] + + Returns: + Dictionary with at least 'loss' key containing the training loss + """ + raise NotImplementedError diff --git a/train.py b/train.py index 440aa01487..4a73257674 100755 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs from timm.utils import ApexScaler, NativeScaler -from timm.kd import DistillationTeacher, apply_kd_loss +from timm.task import DistillationTeacher, ClassificationTask, LogitDistillationTask, FeatureDistillationTask try: from apex import amp @@ -421,10 +421,21 @@ # Knowledge Distillation parameters parser.add_argument('--kd-model-name', default=None, type=str, help='Name of teacher model for knowledge distillation') -parser.add_argument('--alpha-kd', default=5, type=float, - help='Weight for KD loss (default: 5)') -parser.add_argument('--use-kd-only-loss', action='store_true', default=False, - help='Use only KD loss, without cross-entropy loss') +parser.add_argument('--kd-distill-type', default='logit', type=str, choices=['logit', 'feature'], + help='Type of distillation: "logit" for output distillation, "feature" for intermediate features (default: logit)') +parser.add_argument('--kd-loss-type', default='kl', type=str, + help='Loss function for logit distillation (default: kl). Currently only "kl" supported, reserved for future extensions.') +parser.add_argument('--distill-loss-weight', default=None, type=float, + help='Weight for distillation loss. If both weights specified: loss = task_weight * task + distill_weight * distill. ' + 'If only task_weight: loss = task_weight * task + (1-task_weight) * distill. Default: 1.0 if only this specified.') +parser.add_argument('--task-loss-weight', default=None, type=float, + help='Weight for task (classification) loss. See --distill-loss-weight for weighting modes. 
Default: 1.0 if unspecified.') +parser.add_argument('--kd-temperature', default=4.0, type=float, + help='Temperature for softmax in distillation (default: 4.0, typical range: 1-4)') +parser.add_argument('--kd-student-feature-dim', default=None, type=int, + help='Student model feature dimension (auto-detected from model.head_hidden_size or model.num_features if not specified)') +parser.add_argument('--kd-teacher-feature-dim', default=None, type=int, + help='Teacher model feature dimension (auto-detected from model.head_hidden_size or model.num_features if not specified)') def _parse_args(): @@ -540,16 +551,8 @@ def main(): if args.grad_checkpointing: model.set_grad_checkpointing(enable=True) - # Create the KD teacher model if specified - model_kd = None - if args.kd_model_name is not None: - model_kd = DistillationTeacher( - model_name=args.kd_model_name, - num_classes=args.num_classes, - in_chans=in_chans, - device=device, - dtype=model_dtype, - ) + # Create training task (classification or distillation) + task = None if utils.is_primary(args): _logger.info( @@ -677,22 +680,22 @@ def main(): ) # setup distributed training - if args.distributed: - if has_apex and use_amp == 'apex': - # Apex DDP preferred unless native amp is activated - if utils.is_primary(args): - _logger.info("Using NVIDIA APEX DistributedDataParallel.") - model = ApexDDP(model, delay_allreduce=True) - else: - if utils.is_primary(args): - _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) - # NOTE: EMA model does not need to be wrapped by DDP - - if args.torchcompile: - # torch compile should be done after DDP - assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' - model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode) + # if args.distributed: + # if has_apex and use_amp == 'apex': + # # Apex DDP preferred unless native amp is activated + # if utils.is_primary(args): + # _logger.info("Using NVIDIA APEX DistributedDataParallel.") + # model = ApexDDP(model, delay_allreduce=True) + # else: + # if utils.is_primary(args): + # _logger.info("Using native Torch DistributedDataParallel.") + # model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) + # # NOTE: EMA model does not need to be wrapped by DDP + + # if args.torchcompile: + # # torch compile should be done after DDP + # assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 
+ # model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode) # create the train and eval datasets if args.data and not args.data_dir: @@ -927,6 +930,69 @@ def main(): train_loss_fn = train_loss_fn.to(device=device) validate_loss_fn = nn.CrossEntropyLoss().to(device=device) + # Setup training task (classification or distillation) + if args.kd_model_name is not None: + # Create teacher model + teacher = DistillationTeacher( + model_name=args.kd_model_name, + num_classes=args.num_classes, + in_chans=in_chans, + device=device, + dtype=model_dtype, + ) + + # Create distillation task + if args.kd_distill_type == 'logit': + task = LogitDistillationTask( + student_model=model, + teacher=teacher, + criterion=train_loss_fn, + loss_type=args.kd_loss_type, + distill_loss_weight=args.distill_loss_weight, + task_loss_weight=args.task_loss_weight, + temperature=args.kd_temperature, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + elif args.kd_distill_type == 'feature': + task = FeatureDistillationTask( + student_model=model, + teacher=teacher, + criterion=train_loss_fn, + distill_loss_weight=args.distill_loss_weight, + task_loss_weight=args.task_loss_weight, + student_feature_dim=args.kd_student_feature_dim, + teacher_feature_dim=args.kd_teacher_feature_dim, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + else: + raise ValueError(f"Unknown distillation type: {args.kd_distill_type}") + else: + # Standard classification task + task = ClassificationTask( + model=model, + criterion=train_loss_fn, + device=device, + dtype=model_dtype, + verbose=utils.is_primary(args), + ) + + # Prepare task for distributed training + if args.distributed: + if utils.is_primary(args): + _logger.info("Preparing task for distributed training") + task.prepare_distributed(device_ids=[device]) + + # Compile task if requested (should be done after DDP) + if args.torchcompile: + assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 
+ if utils.is_primary(args): + _logger.info(f"Compiling task with backend={args.torchcompile}, mode={args.torchcompile_mode}") + task = torch.compile(task, backend=args.torchcompile, mode=args.torchcompile_mode) + # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric if loader_eval is not None else 'loss' decreasing_metric = eval_metric == 'loss' @@ -1015,8 +1081,8 @@ def main(): model, loader_train, optimizer, - train_loss_fn, args, + task=task, device=device, lr_scheduler=lr_scheduler, saver=saver, @@ -1028,7 +1094,6 @@ def main(): mixup_fn=mixup_fn, num_updates_total=num_epochs * updates_per_epoch, naflex_mode=naflex_mode, - model_kd=model_kd, ) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): @@ -1131,8 +1196,8 @@ def train_one_epoch( model, loader, optimizer, - loss_fn, args, + task=None, device=torch.device('cuda'), lr_scheduler=None, saver=None, @@ -1144,7 +1209,6 @@ def train_one_epoch( mixup_fn=None, num_updates_total=None, naflex_mode=False, - model_kd=None, ): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -1189,24 +1253,13 @@ def train_one_epoch( def _forward(): with amp_autocast(): - output = model(input) - _loss = loss_fn(output, target) - - # KD logic - if model_kd is not None: - _loss = apply_kd_loss( - loss=_loss, - student_output=output, - input=input, - student_model=model, - teacher_model=model_kd, - alpha_kd=args.alpha_kd, - use_kd_only=args.use_kd_only_loss, - ) + # Task handles the complete forward pass and loss computation + result = task(input, target) + _loss = result['loss'] if accum_steps > 1: _loss /= accum_steps - return _loss + return _loss, result def _backward(_loss): if loss_scaler is not None: @@ -1255,13 +1308,13 @@ def _backward(_loss): if has_no_sync and not need_update: with model.no_sync(): - loss = _forward() + loss, result = _forward() scaled_loss = local_scale * loss if dist_scale is not None: scaled_loss *= dist_scale _backward(scaled_loss) else: - loss = _forward() + loss, result = _forward() scaled_loss = local_scale * loss if dist_scale is not None: scaled_loss *= dist_scale @@ -1273,10 +1326,10 @@ def _backward(_loss): if has_no_sync and not need_update: with model.no_sync(): - loss = _forward() + loss, result = _forward() _backward(loss) else: - loss = _forward() + loss, result = _forward() _backward(loss) losses_m.update(loss.item() * accum_steps, batch_size) From d386cc305f4060d739d6eb21d24dd783bae979aa Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 1 Dec 2025 13:03:37 -0800 Subject: [PATCH 6/6] Remove apex AMP use from scripts --- benchmark.py | 7 ----- inference.py | 8 +---- timm/task/distillation.py | 2 +- train.py | 65 ++++++--------------------------------- validate.py | 31 ++++--------------- 5 files changed, 18 insertions(+), 95 deletions(-) diff --git a/benchmark.py b/benchmark.py index db8fe4308d..beaf257a1d 100755 --- a/benchmark.py +++ b/benchmark.py @@ -25,13 +25,6 @@ from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\ reparameterize_model -has_apex = False -try: - from apex import amp - has_apex = True -except ImportError: - pass - try: from deepspeed.profiling.flops_profiler import get_model_profile has_deepspeed_profiling = True diff --git a/inference.py b/inference.py index 7ccaa334ce..21db6194fa 100755 --- a/inference.py +++ b/inference.py @@ -23,12 +23,6 @@ from timm.models import create_model from timm.utils import AverageMeter, 
setup_default_logging, set_jit_fuser, ParseKwargs -try: - from apex import amp - has_apex = True -except ImportError: - has_apex = False - try: from functorch.compile import memory_efficient_fusion has_functorch = True @@ -170,7 +164,7 @@ def main(): assert args.model_dtype in ('float32', 'float16', 'bfloat16') model_dtype = getattr(torch, args.model_dtype) - # resolve AMP arguments based on PyTorch / Apex availability + # resolve AMP arguments based on PyTorch availability amp_autocast = suppress if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' diff --git a/timm/task/distillation.py b/timm/task/distillation.py index d7f83c5330..ff92b44d93 100644 --- a/timm/task/distillation.py +++ b/timm/task/distillation.py @@ -1,6 +1,6 @@ """Knowledge distillation training tasks and components.""" import logging -from typing import Dict, Optional, Literal, Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn as nn diff --git a/train.py b/train.py index 4a73257674..efa7725506 100755 --- a/train.py +++ b/train.py @@ -30,7 +30,6 @@ import torch.nn as nn import torchvision.utils import yaml -from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm import utils from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \ @@ -40,17 +39,9 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler_v2, scheduler_kwargs -from timm.utils import ApexScaler, NativeScaler +from timm.utils import NativeScaler from timm.task import DistillationTeacher, ClassificationTask, LogitDistillationTask, FeatureDistillationTask -try: - from apex import amp - from apex.parallel import DistributedDataParallel as ApexDDP - from apex.parallel import convert_syncbn_model - has_apex = True -except ImportError: - has_apex = False - try: import wandb @@ -174,11 +165,9 @@ group.add_argument('--device', default='cuda', type=str, help="Device (accelerator) to use.") group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use AMP for mixed precision training') group.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') -group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--model-dtype', default=None, type=str, help='Model dtype override (non-AMP) (default: float32)') group.add_argument('--no-ddp-bb', action='store_true', default=False, @@ -346,7 +335,7 @@ group.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') group.add_argument('--sync-bn', action='store_true', - help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') + help='Enable synchronized BatchNorm.') group.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') group.add_argument('--split-bn', action='store_true', @@ -485,18 +474,11 @@ def main(): if model_dtype == torch.float16: _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.') - # resolve AMP arguments based on PyTorch / Apex availability - use_amp = None + # resolve AMP arguments based on PyTorch 
availability amp_dtype = torch.float16 if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' - if args.amp_impl == 'apex': - assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' - use_amp = 'apex' - assert args.amp_dtype == 'float16' - else: - use_amp = 'native' - assert args.amp_dtype in ('float16', 'bfloat16') + assert args.amp_dtype in ('float16', 'bfloat16') if args.amp_dtype == 'bfloat16': amp_dtype = torch.bfloat16 @@ -580,12 +562,7 @@ def main(): if args.distributed and args.sync_bn: args.dist_bn = '' # disable dist_bn when sync BN active assert not args.split_bn - if has_apex and use_amp == 'apex': - # Apex SyncBN used with Apex AMP - # WARNING this won't currently work with models using BatchNormAct2d - model = convert_syncbn_model(model) - else: - model = convert_sync_batchnorm(model) + model = convert_sync_batchnorm(model) if utils.is_primary(args): _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' @@ -598,7 +575,6 @@ def main(): if args.torchscript: assert not args.torchcompile - assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) @@ -632,13 +608,7 @@ def main(): # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing loss_scaler = None - if use_amp == 'apex': - assert device.type == 'cuda' - model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - loss_scaler = ApexScaler() - if utils.is_primary(args): - _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') - elif use_amp == 'native': + if args.amp: amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) if device.type in ('cuda',) and amp_dtype == torch.float16: # loss scaler only used for float16 (half) dtype, bfloat16 does not need it @@ -679,24 +649,6 @@ def main(): mode=args.torchcompile_mode, ) - # setup distributed training - # if args.distributed: - # if has_apex and use_amp == 'apex': - # # Apex DDP preferred unless native amp is activated - # if utils.is_primary(args): - # _logger.info("Using NVIDIA APEX DistributedDataParallel.") - # model = ApexDDP(model, delay_allreduce=True) - # else: - # if utils.is_primary(args): - # _logger.info("Using native Torch DistributedDataParallel.") - # model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb) - # # NOTE: EMA model does not need to be wrapped by DDP - - # if args.torchcompile: - # # torch compile should be done after DDP - # assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 
- # model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode) - # create the train and eval datasets if args.data and not args.data_dir: args.data_dir = args.data @@ -1177,6 +1129,9 @@ def main(): except KeyboardInterrupt: pass + if args.distributed: + torch.distributed.destroy_process_group() + if best_metric is not None: # log best metric as tracked by checkpoint saver _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) diff --git a/validate.py b/validate.py index 75657a764d..03c572929d 100755 --- a/validate.py +++ b/validate.py @@ -28,11 +28,6 @@ from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model -try: - from apex import amp - has_apex = True -except ImportError: - has_apex = False try: from functorch.compile import memory_efficient_fusion @@ -124,11 +119,9 @@ parser.add_argument('--device', default='cuda', type=str, help="Device (accelerator) to use.") parser.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use Native AMP for mixed precision inference') parser.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') -parser.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') parser.add_argument('--model-dtype', default=None, type=str, help='Model dtype override (non-AMP) (default: float32)') parser.add_argument('--tf-preprocessing', action='store_true', default=False, @@ -197,22 +190,14 @@ def validate(args): assert args.model_dtype in ('float32', 'float16', 'bfloat16') model_dtype = getattr(torch, args.model_dtype) - # resolve AMP arguments based on PyTorch / Apex availability - use_amp = None + # resolve AMP arguments based on PyTorch availability amp_autocast = suppress if args.amp: assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP' - if args.amp_impl == 'apex': - assert has_apex, 'AMP impl specified as APEX but APEX is not installed.' - assert args.amp_dtype == 'float16' - use_amp = 'apex' - _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') - else: - assert args.amp_dtype in ('float16', 'bfloat16') - use_amp = 'native' - amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 - amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) - _logger.info('Validating in mixed precision with native PyTorch AMP.') + assert args.amp_dtype in ('float16', 'bfloat16') + amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) + _logger.info('Validating in mixed precision with native PyTorch AMP.') else: _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.') @@ -266,7 +251,6 @@ def validate(args): model = model.to(memory_format=torch.channels_last) if args.torchscript: - assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' model = torch.jit.script(model) elif args.torchcompile: assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.' 
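For reference, the mixed-precision setup that validate.py is left with after dropping the apex branches reduces to the native torch.autocast path. The following is a condensed sketch, not a verbatim excerpt of the script; it assumes the variable names used in the surrounding hunks (args.amp, args.amp_dtype, device):

    from contextlib import suppress
    from functools import partial

    import torch

    def resolve_amp_autocast(args, device):
        # no-op context manager when AMP is disabled
        amp_autocast = suppress
        if args.amp:
            assert args.amp_dtype in ('float16', 'bfloat16')
            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
            # native PyTorch AMP is now the only mixed-precision path
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        return amp_autocast

Callers keep using "with amp_autocast():" unchanged, which is why only the resolution logic needed to be touched.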
@@ -276,9 +260,6 @@ def validate(args): assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model) - if use_amp == 'apex': - model = amp.initialize(model, opt_level='O1') - if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
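Taken together, these patches replace the per-branch KD handling in train_one_epoch with a task object whose forward returns a dict containing 'loss'. Below is a minimal sketch of that call protocol under illustrative names; the real ClassificationTask / LogitDistillationTask / FeatureDistillationTask implementations live under timm.task and carry more options than shown here:

    import torch
    import torch.nn as nn

    class ToyClassificationTask(nn.Module):
        """Illustrative stand-in for the timm.task classes; only the call contract matters."""
        def __init__(self, model, criterion):
            super().__init__()
            self.model = model
            self.criterion = criterion

        def forward(self, input, target):
            output = self.model(input)
            loss = self.criterion(output, target)
            # train_one_epoch only reads result['loss']; extra keys are free for logging
            return {'loss': loss, 'output': output}

    model = nn.Linear(8, 4)
    task = ToyClassificationTask(model, nn.CrossEntropyLoss())
    input, target = torch.randn(2, 8), torch.tensor([0, 3])
    result = task(input, target)
    result['loss'].backward()

A distillation task fits the same loop by folding the teacher forward and the KD term into forward(), which is what allows the model_kd plumbing and the explicit KD-loss call to be removed from train.py.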