22
33Hacked together by / Copyright 2020 Ross Wightman
44"""
5+ import torch
56
7+ EVAL_VERIFICATION_RATES = [0.01 , 0.02 , 0.05 , 0.1 , 0.2 ]
68
79class AverageMeter :
810 """Computes and stores the average and current value"""
@@ -22,6 +24,56 @@ def update(self, val, n=1):
2224 self .avg = self .sum / self .count
2325
2426
27+ class CorrectnessOfPredictionsWithConfidencesMeter :
28+ def __init__ (self ):
29+ self .reset ()
30+
31+ def reset (self ):
32+ self .predictions_correct = []
33+ self .confidences = []
34+
35+ def update (self , output , target ):
36+ confidences , preds = output .topk (k = 1 )
37+ preds = preds .t ()
38+ correct = preds .eq (target .reshape (1 , - 1 ).expand_as (preds )).flatten ()
39+
40+ self .predictions_correct .append (correct .detach ().cpu ())
41+ self .confidences .append (confidences .detach ().cpu ())
42+
43+ def final_accuracy (self , vrs ):
44+ correct = torch .cat (self .predictions_correct )
45+ confidences = torch .cat (self .confidences )
46+
47+ correct_sorted = correct [confidences .flatten ().argsort ()]
48+ N = len (correct_sorted )
49+
50+ def _fa (vr ):
51+ n_verified = round (vr * N )
52+ return (n_verified + correct_sorted [n_verified :].sum ()) / N
53+
54+ return [_fa (vr ) for vr in vrs ]
55+
56+ def average_final_accuracy (self , vrs ):
57+ correct = torch .cat (self .predictions_correct )
58+ confidences = torch .cat (self .confidences )
59+
60+ correct_sorted = correct [confidences .flatten ().argsort ()]
61+ N = len (correct_sorted )
62+
63+ def _afa (vr ):
64+ # see https://drive.google.com/file/d/1Uag8VtD3RwsoS8hs59X6T5u_iwuqspkS/view
65+ # for derivation of this formula
66+ n_verified = round (vr * N )
67+ afa_weights = torch .arange (1 , N + 1 ) / n_verified
68+ return (
69+ (n_verified - 1 ) / 2
70+ + (afa_weights [:n_verified ] * correct_sorted [:n_verified ]).sum ()
71+ + correct_sorted [n_verified :].sum ()
72+ ) / N
73+
74+ return [_afa (vr ) for vr in vrs ]
75+
76+
2577def accuracy (output , target , topk = (1 ,)):
2678 """Computes the accuracy over the k top predictions for the specified values of k"""
2779 maxk = min (max (topk ), output .size ()[1 ])
0 commit comments