validate.py: 54 changes (51 additions, 3 deletions)
@@ -42,6 +42,12 @@
 
 has_compile = hasattr(torch, 'compile')
 
+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
 _logger = logging.getLogger('validate')
 
 
@@ -158,6 +164,11 @@
 parser.add_argument('--retry', default=False, action='store_true',
                     help='Enable batch size decay & retry for single model validation')
 
+parser.add_argument('--metrics-avg', type=str, default=None,
+                    choices=['micro', 'macro', 'weighted'],
+                    help='Enable precision, recall, F1-score calculation and specify the averaging method. '
+                         'Requires scikit-learn. (default: None)')
+
 # NaFlex loader arguments
 parser.add_argument('--naflex-loader', action='store_true', default=False,
                     help='Use NaFlex loader (Requires NaFlex compatible model)')
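
A quick sketch of how the new flag parses, using only the argparse setup shown above (the argv values are illustrative): omitting --metrics-avg leaves the option at None, so all metric code stays disabled, and any value outside the three choices makes argparse exit with a usage error.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--metrics-avg', type=str, default=None,
                    choices=['micro', 'macro', 'weighted'])

args = parser.parse_args(['--metrics-avg', 'macro'])
print(args.metrics_avg)                    # 'macro'
print(parser.parse_args([]).metrics_avg)   # None -> metrics stay disabled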
@@ -176,6 +187,11 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    if args.metrics_avg and not has_sklearn:
+        _logger.warning(
+            "scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
+        args.metrics_avg = None
+
     model_dtype = None
     if args.model_dtype:
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
@@ -346,6 +362,10 @@ def validate(args):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    if args.metrics_avg:
+        all_preds = []
+        all_targets = []
+
     model.eval()
     with torch.inference_mode():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
@@ -382,6 +402,11 @@
             top1.update(acc1.item(), batch_size)
             top5.update(acc5.item(), batch_size)
 
+            if args.metrics_avg:
+                predictions = torch.argmax(output, dim=1)
+                all_preds.append(predictions.cpu())
+                all_targets.append(target.cpu())
+
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
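
For intuition, a toy run of the prediction step above (the logits are made up): argmax over dim=1 collapses a [batch, num_classes] logit tensor to integer class ids, and .cpu() moves each batch off the accelerator so the accumulated lists don't pin device memory for the whole validation run.

import torch

output = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3]])   # fake logits: batch of 2, 3 classes
predictions = torch.argmax(output, dim=1)
print(predictions.cpu())                   # tensor([1, 0])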
@@ -408,18 +433,41 @@ def validate(args):
         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
     else:
         top1a, top5a = top1.avg, top5.avg
 
+    metric_results = {}
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds).numpy()
+        all_targets = torch.cat(all_targets).numpy()
+        precision = precision_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        recall = recall_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        f1 = f1_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        metric_results = {
+            f'{args.metrics_avg}_precision': round(precision, 4),
+            f'{args.metrics_avg}_recall': round(recall, 4),
+            f'{args.metrics_avg}_f1_score': round(f1, 4),
+        }
+
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
+        **metric_results,
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
         interpolation=data_config['interpolation'],
     )
 
-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
-        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+    log_string = ' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err'])
+    if metric_results:
+        log_string += ' | Precision({avg}) {prec:.3f} | Recall({avg}) {rec:.3f} | F1-score({avg}) {f1:.3f}'.format(
+            avg=args.metrics_avg,
+            prec=metric_results[f'{args.metrics_avg}_precision'],
+            rec=metric_results[f'{args.metrics_avg}_recall'],
+            f1=metric_results[f'{args.metrics_avg}_f1_score'],
+        )
+    _logger.info(log_string)
+
     return results
 
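As a sanity check on the averaging choices, a standalone sketch with made-up labels shows how 'micro', 'macro', and 'weighted' can diverge on imbalanced classes; the sklearn calls mirror the ones added above.

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

targets = np.array([0, 0, 0, 0, 1, 2])   # class 0 dominates
preds = np.array([0, 0, 0, 1, 1, 0])     # class 2 is never predicted

for avg in ('micro', 'macro', 'weighted'):
    p = precision_score(targets, preds, average=avg, zero_division=0)
    r = recall_score(targets, preds, average=avg, zero_division=0)
    f = f1_score(targets, preds, average=avg, zero_division=0)
    print(f'{avg}: precision={p:.4f} recall={r:.4f} f1={f:.4f}')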
@@ -534,4 +582,4 @@ def write_results(results_file, results, format='csv'):
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()