Skip to content

Commit 1b0e853

Browse files
authored
Merge pull request #13 from qedsoftware/eval-verifiable-metrics
Log FA and AFA metrics
2 parents 66fd5d1 + 914ad50 commit 1b0e853

File tree

3 files changed

+85
-2
lines changed

3 files changed

+85
-2
lines changed

timm/train.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,10 @@ def train(config: dict[str, t.Any]):
937937
if eval_metrics:
938938
mlflow.log_metric("val loss", eval_metrics["loss"], step=epoch)
939939
mlflow.log_metric("val accuracy", eval_metrics["top1"], step=epoch)
940+
for vr in utils.EVAL_VERIFICATION_RATES:
941+
mlflow.log_metric(f"FA at {int(100 * vr):03d}", eval_metrics[f"fa@{vr}"])
942+
mlflow.log_metric(f"AFA at {int(100 * vr):03d}", eval_metrics[f"afa@{vr}"])
943+
940944

941945
if output_dir is not None:
942946
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
@@ -1152,6 +1156,7 @@ def validate(
11521156
losses_m = utils.AverageMeter()
11531157
top1_m = utils.AverageMeter()
11541158
top5_m = utils.AverageMeter()
1159+
correct_with_confidences_m = utils.CorrectnessOfPredictionsWithConfidencesMeter()
11551160

11561161
model.eval()
11571162

@@ -1193,6 +1198,7 @@ def validate(
11931198
losses_m.update(reduced_loss.item(), input.size(0))
11941199
top1_m.update(acc1.item(), output.size(0))
11951200
top5_m.update(acc5.item(), output.size(0))
1201+
correct_with_confidences_m.update(output, target)
11961202

11971203
batch_time_m.update(time.time() - end)
11981204
end = time.time()
@@ -1206,7 +1212,32 @@ def validate(
12061212
f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
12071213
)
12081214

1209-
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
1215+
metrics = OrderedDict(
1216+
[
1217+
("loss", losses_m.avg),
1218+
("top1", top1_m.avg),
1219+
("top5", top5_m.avg),
1220+
*[
1221+
(f"fa@{vr}", fa)
1222+
for vr, fa in zip(
1223+
utils.EVAL_VERIFICATION_RATES,
1224+
correct_with_confidences_m.final_accuracy(
1225+
utils.EVAL_VERIFICATION_RATES
1226+
),
1227+
)
1228+
],
1229+
*[
1230+
(f"afa@{vr}", afa)
1231+
for vr, afa in zip(
1232+
utils.EVAL_VERIFICATION_RATES,
1233+
correct_with_confidences_m.average_final_accuracy(
1234+
utils.EVAL_VERIFICATION_RATES
1235+
),
1236+
)
1237+
],
1238+
]
1239+
)
1240+
12101241

12111242
return metrics
12121243

timm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
world_info_from_env, is_distributed_env, is_primary
99
from .jit import set_jit_legacy, set_jit_fuser
1010
from .log import setup_default_logging, FormatterNoInfo
11-
from .metrics import AverageMeter, accuracy
11+
from .metrics import AverageMeter, accuracy, CorrectnessOfPredictionsWithConfidencesMeter, EVAL_VERIFICATION_RATES
1212
from .misc import natural_key, add_bool_arg, ParseKwargs
1313
from .model import unwrap_model, get_state_dict, freeze, unfreeze, reparameterize_model
1414
from .model_ema import ModelEma, ModelEmaV2, ModelEmaV3

timm/utils/metrics.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
33
Hacked together by / Copyright 2020 Ross Wightman
44
"""
5+
import torch
56

7+
EVAL_VERIFICATION_RATES = [0.01, 0.02, 0.05, 0.1, 0.2]
68

79
class AverageMeter:
810
"""Computes and stores the average and current value"""
@@ -22,6 +24,56 @@ def update(self, val, n=1):
2224
self.avg = self.sum / self.count
2325

2426

27+
class CorrectnessOfPredictionsWithConfidencesMeter:
28+
def __init__(self):
29+
self.reset()
30+
31+
def reset(self):
32+
self.predictions_correct = []
33+
self.confidences = []
34+
35+
def update(self, output, target):
36+
confidences, preds = output.topk(k=1)
37+
preds = preds.t()
38+
correct = preds.eq(target.reshape(1, -1).expand_as(preds)).flatten()
39+
40+
self.predictions_correct.append(correct.detach().cpu())
41+
self.confidences.append(confidences.detach().cpu())
42+
43+
def final_accuracy(self, vrs):
44+
correct = torch.cat(self.predictions_correct)
45+
confidences = torch.cat(self.confidences)
46+
47+
correct_sorted = correct[confidences.flatten().argsort()]
48+
N = len(correct_sorted)
49+
50+
def _fa(vr):
51+
n_verified = round(vr * N)
52+
return (n_verified + correct_sorted[n_verified:].sum()) / N
53+
54+
return [_fa(vr) for vr in vrs]
55+
56+
def average_final_accuracy(self, vrs):
57+
correct = torch.cat(self.predictions_correct)
58+
confidences = torch.cat(self.confidences)
59+
60+
correct_sorted = correct[confidences.flatten().argsort()]
61+
N = len(correct_sorted)
62+
63+
def _afa(vr):
64+
# see https://drive.google.com/file/d/1Uag8VtD3RwsoS8hs59X6T5u_iwuqspkS/view
65+
# for derivation of this formula
66+
n_verified = round(vr * N)
67+
afa_weights = torch.arange(1, N + 1) / n_verified
68+
return (
69+
(n_verified - 1) / 2
70+
+ (afa_weights[:n_verified] * correct_sorted[:n_verified]).sum()
71+
+ correct_sorted[n_verified:].sum()
72+
) / N
73+
74+
return [_afa(vr) for vr in vrs]
75+
76+
2577
def accuracy(output, target, topk=(1,)):
2678
"""Computes the accuracy over the k top predictions for the specified values of k"""
2779
maxk = min(max(topk), output.size()[1])

0 commit comments

Comments
 (0)