@@ -30,7 +30,6 @@
 import torch.nn as nn
 import torchvision.utils
 import yaml
-from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm import utils
 from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
@@ -40,17 +39,9 @@
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
-from timm.utils import ApexScaler, NativeScaler
+from timm.utils import NativeScaler
 from timm.task import DistillationTeacher, ClassificationTask, LogitDistillationTask, FeatureDistillationTask
 
-try:
-    from apex import amp
-    from apex.parallel import DistributedDataParallel as ApexDDP
-    from apex.parallel import convert_syncbn_model
-    has_apex = True
-except ImportError:
-    has_apex = False
-
 
 try:
     import wandb
@@ -174,11 +165,9 @@
 group.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
 group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+                   help='use AMP for mixed precision training')
 group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
 group.add_argument('--model-dtype', default=None, type=str,
                    help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
@@ -346,7 +335,7 @@
 group.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
 group.add_argument('--sync-bn', action='store_true',
-                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
+                   help='Enable synchronized BatchNorm.')
 group.add_argument('--dist-bn', type=str, default='reduce',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 group.add_argument('--split-bn', action='store_true',
@@ -485,18 +474,11 @@ def main():
     if model_dtype == torch.float16:
         _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
 
-    # resolve AMP arguments based on PyTorch / Apex availability
-    use_amp = None
+    # resolve AMP arguments based on PyTorch availability
     amp_dtype = torch.float16
     if args.amp:
         assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
-        if args.amp_impl == 'apex':
-            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
-            use_amp = 'apex'
-            assert args.amp_dtype == 'float16'
-        else:
-            use_amp = 'native'
-            assert args.amp_dtype in ('float16', 'bfloat16')
+        assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
             amp_dtype = torch.bfloat16
 
@@ -580,12 +562,7 @@ def main():
     if args.distributed and args.sync_bn:
         args.dist_bn = ''  # disable dist_bn when sync BN active
         assert not args.split_bn
-        if has_apex and use_amp == 'apex':
-            # Apex SyncBN used with Apex AMP
-            # WARNING this won't currently work with models using BatchNormAct2d
-            model = convert_syncbn_model(model)
-        else:
-            model = convert_sync_batchnorm(model)
+        model = convert_sync_batchnorm(model)
         if utils.is_primary(args):
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
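
Note (not part of the diff): with the Apex convert_syncbn_model branch removed, --sync-bn always routes through convert_sync_batchnorm, timm's own helper, which unlike the removed Apex converter is meant to cope with its BatchNormAct2d layers. A minimal sketch of the equivalent conversion using plain PyTorch, shown only as an illustration:

    import torch.nn as nn

    # Conversion just swaps module types; actually training the converted model
    # requires an initialized torch.distributed process group.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(type(model[1]))  # BatchNorm2d has been replaced by SyncBatchNorm
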
@@ -598,7 +575,6 @@ def main():
 
     if args.torchscript:
         assert not args.torchcompile
-        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
 
@@ -632,13 +608,7 @@ def main():
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
-    if use_amp == 'apex':
-        assert device.type == 'cuda'
-        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
-        loss_scaler = ApexScaler()
-        if utils.is_primary(args):
-            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
-    elif use_amp == 'native':
+    if args.amp:
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
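
Note (not part of the diff): with the Apex path gone, mixed precision always goes through torch.autocast, plus a GradScaler (wrapped by timm's NativeScaler) only when the AMP dtype is float16. An illustrative sketch of that pattern; the tiny model, optimizer, and data below are placeholders, not the script's own objects:

    import torch
    import torch.nn as nn

    device = torch.device('cuda')
    model = nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()

    amp_dtype = torch.float16  # bfloat16 also works and needs no loss scaling
    scaler = torch.cuda.amp.GradScaler(enabled=amp_dtype == torch.float16)

    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 4, (8,), device=device)

    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=amp_dtype):
        loss = criterion(model(x), y)   # forward/loss run in reduced precision
    scaler.scale(loss).backward()       # scale the loss to avoid float16 grad underflow
    scaler.step(optimizer)              # unscales grads, then calls optimizer.step()
    scaler.update()                     # adjusts the scale factor for the next step
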
@@ -679,24 +649,6 @@ def main():
             mode=args.torchcompile_mode,
         )
 
-    # setup distributed training
-    # if args.distributed:
-    #     if has_apex and use_amp == 'apex':
-    #         # Apex DDP preferred unless native amp is activated
-    #         if utils.is_primary(args):
-    #             _logger.info("Using NVIDIA APEX DistributedDataParallel.")
-    #         model = ApexDDP(model, delay_allreduce=True)
-    #     else:
-    #         if utils.is_primary(args):
-    #             _logger.info("Using native Torch DistributedDataParallel.")
-    #         model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
-    #     # NOTE: EMA model does not need to be wrapped by DDP
-
-    # if args.torchcompile:
-    #     # torch compile should be done after DDP
-    #     assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-    #     model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
-
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
@@ -1177,6 +1129,9 @@ def main():
     except KeyboardInterrupt:
         pass
 
+    if args.distributed:
+        torch.distributed.destroy_process_group()
+
     if best_metric is not None:
         # log best metric as tracked by checkpoint saver
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
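
Note (not part of the diff): the added destroy_process_group() call gives distributed runs an explicit teardown at the end of main(), pairing with the init_process_group() done during setup and helping avoid shutdown warnings from newer PyTorch releases. A minimal sketch of that pairing, for illustration only; it assumes launch via torchrun, which provides WORLD_SIZE and LOCAL_RANK:

    import os
    import torch

    def main():
        distributed = int(os.environ.get('WORLD_SIZE', 1)) > 1
        if distributed:
            torch.distributed.init_process_group(backend='nccl')
            torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

        # ... build model, run training epochs ...

        if distributed:
            torch.distributed.destroy_process_group()  # clean shutdown of the NCCL backend

    if __name__ == '__main__':
        main()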