diff --git a/validate.py b/validate.py
index e5ff9494bf..2c8550efcc 100755
--- a/validate.py
+++ b/validate.py
@@ -42,6 +42,12 @@
 
 has_compile = hasattr(torch, 'compile')
 
+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
 _logger = logging.getLogger('validate')
 
 
@@ -158,6 +164,11 @@
 parser.add_argument('--retry', default=False, action='store_true',
                     help='Enable batch size decay & retry for single model validation')
 
+parser.add_argument('--metrics-avg', type=str, default=None,
+                    choices=['micro', 'macro', 'weighted'],
+                    help='Enable precision, recall, F1-score calculation and specify the averaging method. '
+                         'Requires scikit-learn. (default: None)')
+
 # NaFlex loader arguments
 parser.add_argument('--naflex-loader', action='store_true', default=False,
                     help='Use NaFlex loader (Requires NaFlex compatible model)')
@@ -176,6 +187,11 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    if args.metrics_avg and not has_sklearn:
+        _logger.warning(
+            "scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
+        args.metrics_avg = None
+
     model_dtype = None
     if args.model_dtype:
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
@@ -346,6 +362,10 @@ def validate(args):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    if args.metrics_avg:
+        all_preds = []
+        all_targets = []
+
     model.eval()
     with torch.inference_mode():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
@@ -382,6 +402,11 @@
             top1.update(acc1.item(), batch_size)
             top5.update(acc5.item(), batch_size)
 
+            if args.metrics_avg:
+                predictions = torch.argmax(output, dim=1)
+                all_preds.append(predictions.cpu())
+                all_targets.append(target.cpu())
+
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
@@ -408,18 +433,41 @@
         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
     else:
         top1a, top5a = top1.avg, top5.avg
+
+    metric_results = {}
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds).numpy()
+        all_targets = torch.cat(all_targets).numpy()
+        precision = precision_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        recall = recall_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        f1 = f1_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        metric_results = {
+            f'{args.metrics_avg}_precision': round(precision, 4),
+            f'{args.metrics_avg}_recall': round(recall, 4),
+            f'{args.metrics_avg}_f1_score': round(f1, 4),
+        }
+
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
+        **metric_results,
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
         interpolation=data_config['interpolation'],
     )
 
-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
-        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+    log_string = ' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err'])
+    if metric_results:
+        log_string += ' | Precision({avg}) {prec:.3f} | Recall({avg}) {rec:.3f} | F1-score({avg}) {f1:.3f}'.format(
+            avg=args.metrics_avg,
+            prec=metric_results[f'{args.metrics_avg}_precision'],
+            rec=metric_results[f'{args.metrics_avg}_recall'],
+            f1=metric_results[f'{args.metrics_avg}_f1_score'],
+        )
+    _logger.info(log_string)
 
     return results
 
@@ -534,4 +582,4 @@ def write_results(results_file, results, format='csv'):
 
 
 if __name__ == '__main__':
-    main()
+    main()
\ No newline at end of file
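
Illustrative sketch (not part of the patch): the snippet below mirrors how the new --metrics-avg path accumulates per-batch predictions and then computes precision, recall, and F1 with scikit-learn. The random logits, targets, and the three-iteration loop are fabricated stand-ins for the real model output and validation loader.

# Minimal standalone sketch, assuming torch and scikit-learn are installed.
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

avg = 'macro'  # same choices as --metrics-avg: 'micro', 'macro', 'weighted'
all_preds, all_targets = [], []
for _ in range(3):                          # stand-in for the validation loop
    output = torch.randn(8, 10)             # fake logits: batch of 8, 10 classes
    target = torch.randint(0, 10, (8,))     # fake ground-truth labels
    all_preds.append(torch.argmax(output, dim=1).cpu())
    all_targets.append(target.cpu())

preds = torch.cat(all_preds).numpy()
targets = torch.cat(all_targets).numpy()
metric_results = {
    f'{avg}_precision': round(float(precision_score(targets, preds, average=avg, zero_division=0)), 4),
    f'{avg}_recall': round(float(recall_score(targets, preds, average=avg, zero_division=0)), 4),
    f'{avg}_f1_score': round(float(f1_score(targets, preds, average=avg, zero_division=0)), 4),
}
print(metric_results)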