
Commit fee045e

Remove apex AMP use from scripts
1 parent 84014b1 commit fee045e

5 files changed: +18, -95 lines changed

benchmark.py

Lines changed: 0 additions & 7 deletions
@@ -25,13 +25,6 @@
 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
     reparameterize_model
 
-has_apex = False
-try:
-    from apex import amp
-    has_apex = True
-except ImportError:
-    pass
-
 try:
     from deepspeed.profiling.flops_profiler import get_model_profile
     has_deepspeed_profiling = True

inference.py

Lines changed: 1 addition & 7 deletions
@@ -23,12 +23,6 @@
 from timm.models import create_model
 from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
 
-try:
-    from apex import amp
-    has_apex = True
-except ImportError:
-    has_apex = False
-
 try:
     from functorch.compile import memory_efficient_fusion
     has_functorch = True
@@ -170,7 +164,7 @@ def main():
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
         model_dtype = getattr(torch, args.model_dtype)
 
-    # resolve AMP arguments based on PyTorch / Apex availability
+    # resolve AMP arguments based on PyTorch availability
     amp_autocast = suppress
     if args.amp:
         assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
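The apex try/except guards are dropped outright, while the same optional-import pattern is kept for the remaining optional dependencies such as functorch and deepspeed. Below is a minimal sketch of that guard pattern; the maybe_fuse() helper is hypothetical and only illustrates how the flag is consumed, it is not a function in these scripts.

# Optional-import guard: set a module-level flag instead of failing at import time.
try:
    from functorch.compile import memory_efficient_fusion  # optional dependency
    has_functorch = True
except ImportError:
    has_functorch = False


def maybe_fuse(model):
    # hypothetical helper: apply the optional optimization only when available
    return memory_efficient_fusion(model) if has_functorch else model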

timm/task/distillation.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 """Knowledge distillation training tasks and components."""
 import logging
-from typing import Dict, Optional, Literal, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn

train.py

Lines changed: 10 additions & 55 deletions
@@ -30,7 +30,6 @@
 import torch.nn as nn
 import torchvision.utils
 import yaml
-from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm import utils
 from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
@@ -40,17 +39,9 @@
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler_v2, scheduler_kwargs
-from timm.utils import ApexScaler, NativeScaler
+from timm.utils import NativeScaler
 from timm.task import DistillationTeacher, ClassificationTask, LogitDistillationTask, FeatureDistillationTask
 
-try:
-    from apex import amp
-    from apex.parallel import DistributedDataParallel as ApexDDP
-    from apex.parallel import convert_syncbn_model
-    has_apex = True
-except ImportError:
-    has_apex = False
-
 
 try:
     import wandb
@@ -174,11 +165,9 @@
 group.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
 group.add_argument('--amp', action='store_true', default=False,
-                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+                   help='use AMP for mixed precision training')
 group.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
-group.add_argument('--amp-impl', default='native', type=str,
-                   help='AMP impl to use, "native" or "apex" (default: native)')
 group.add_argument('--model-dtype', default=None, type=str,
                    help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
@@ -346,7 +335,7 @@
 group.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
 group.add_argument('--sync-bn', action='store_true',
-                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
+                   help='Enable synchronized BatchNorm.')
 group.add_argument('--dist-bn', type=str, default='reduce',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 group.add_argument('--split-bn', action='store_true',
@@ -485,18 +474,11 @@ def main():
         if model_dtype == torch.float16:
             _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
 
-    # resolve AMP arguments based on PyTorch / Apex availability
-    use_amp = None
+    # resolve AMP arguments based on PyTorch availability
     amp_dtype = torch.float16
     if args.amp:
         assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
-        if args.amp_impl == 'apex':
-            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
-            use_amp = 'apex'
-            assert args.amp_dtype == 'float16'
-        else:
-            use_amp = 'native'
-            assert args.amp_dtype in ('float16', 'bfloat16')
+        assert args.amp_dtype in ('float16', 'bfloat16')
         if args.amp_dtype == 'bfloat16':
             amp_dtype = torch.bfloat16
 
@@ -580,12 +562,7 @@ def main():
     if args.distributed and args.sync_bn:
         args.dist_bn = ''  # disable dist_bn when sync BN active
         assert not args.split_bn
-        if has_apex and use_amp == 'apex':
-            # Apex SyncBN used with Apex AMP
-            # WARNING this won't currently work with models using BatchNormAct2d
-            model = convert_syncbn_model(model)
-        else:
-            model = convert_sync_batchnorm(model)
+        model = convert_sync_batchnorm(model)
         if utils.is_primary(args):
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
@@ -598,7 +575,6 @@ def main():
 
     if args.torchscript:
         assert not args.torchcompile
-        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
 
@@ -632,13 +608,7 @@ def main():
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
-    if use_amp == 'apex':
-        assert device.type == 'cuda'
-        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
-        loss_scaler = ApexScaler()
-        if utils.is_primary(args):
-            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
-    elif use_amp == 'native':
+    if args.amp:
         amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
         if device.type in ('cuda',) and amp_dtype == torch.float16:
             # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
@@ -679,24 +649,6 @@ def main():
             mode=args.torchcompile_mode,
         )
 
-    # setup distributed training
-    # if args.distributed:
-    #     if has_apex and use_amp == 'apex':
-    #         # Apex DDP preferred unless native amp is activated
-    #         if utils.is_primary(args):
-    #             _logger.info("Using NVIDIA APEX DistributedDataParallel.")
-    #         model = ApexDDP(model, delay_allreduce=True)
-    #     else:
-    #         if utils.is_primary(args):
-    #             _logger.info("Using native Torch DistributedDataParallel.")
-    #         model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
-    #     # NOTE: EMA model does not need to be wrapped by DDP
-
-    # if args.torchcompile:
-    #     # torch compile should be done after DDP
-    #     assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-    #     model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
-
     # create the train and eval datasets
     if args.data and not args.data_dir:
         args.data_dir = args.data
@@ -1177,6 +1129,9 @@ def main():
     except KeyboardInterrupt:
         pass
 
+    if args.distributed:
+        torch.distributed.destroy_process_group()
+
     if best_metric is not None:
         # log best metric as tracked by checkpoint saver
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
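With the apex path gone, train.py relies only on torch.autocast plus a loss scaler (timm's NativeScaler, used solely for float16 on CUDA). Below is a minimal sketch of that native AMP training step under those assumptions, using torch.amp.GradScaler directly rather than the NativeScaler wrapper; the model, optimizer, and data are illustrative stand-ins, not code from the script.

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
amp_dtype = torch.float16  # bfloat16 also works and needs no loss scaler

model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# loss scaler is only required for float16 on CUDA; enabled=False makes it a no-op
# (torch.cuda.amp.GradScaler on older PyTorch releases)
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda' and amp_dtype == torch.float16))

inputs = torch.randn(32, 128, device=device)
targets = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=(device.type == 'cuda')):
    loss = nn.functional.cross_entropy(model(inputs), targets)
scaler.scale(loss).backward()   # scales the loss when enabled, plain backward otherwise
scaler.step(optimizer)          # unscales gradients and steps (or just steps when disabled)
scaler.update()

This mirrors the condition kept in the diff above: the scaler is created only for float16 on CUDA, since bfloat16 training does not need loss scaling.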

validate.py

Lines changed: 6 additions & 25 deletions
@@ -28,11 +28,6 @@
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
     decay_batch_step, check_batch_size_retry, ParseKwargs, reparameterize_model
 
-try:
-    from apex import amp
-    has_apex = True
-except ImportError:
-    has_apex = False
 
 try:
     from functorch.compile import memory_efficient_fusion
@@ -124,11 +119,9 @@
 parser.add_argument('--device', default='cuda', type=str,
                     help="Device (accelerator) to use.")
 parser.add_argument('--amp', action='store_true', default=False,
-                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
+                    help='use Native AMP for mixed precision inference')
 parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
-parser.add_argument('--amp-impl', default='native', type=str,
-                    help='AMP impl to use, "native" or "apex" (default: native)')
 parser.add_argument('--model-dtype', default=None, type=str,
                     help='Model dtype override (non-AMP) (default: float32)')
 parser.add_argument('--tf-preprocessing', action='store_true', default=False,
@@ -197,22 +190,14 @@ def validate(args):
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
         model_dtype = getattr(torch, args.model_dtype)
 
-    # resolve AMP arguments based on PyTorch / Apex availability
-    use_amp = None
+    # resolve AMP arguments based on PyTorch availability
    amp_autocast = suppress
     if args.amp:
         assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
-        if args.amp_impl == 'apex':
-            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
-            assert args.amp_dtype == 'float16'
-            use_amp = 'apex'
-            _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
-        else:
-            assert args.amp_dtype in ('float16', 'bfloat16')
-            use_amp = 'native'
-            amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
-            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-            _logger.info('Validating in mixed precision with native PyTorch AMP.')
+        assert args.amp_dtype in ('float16', 'bfloat16')
+        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
+        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        _logger.info('Validating in mixed precision with native PyTorch AMP.')
     else:
         _logger.info(f'Validating in {model_dtype or torch.float32}. AMP not enabled.')
 
@@ -266,7 +251,6 @@ def validate(args):
         model = model.to(memory_format=torch.channels_last)
 
     if args.torchscript:
-        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         model = torch.jit.script(model)
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
@@ -276,9 +260,6 @@ def validate(args):
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
 
-    if use_amp == 'apex':
-        model = amp.initialize(model, opt_level='O1')
-
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
 