We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1b0e853 commit de8ec78Copy full SHA for de8ec78
timm/utils/metrics.py
@@ -49,7 +49,7 @@ def final_accuracy(self, vrs):
49
50
def _fa(vr):
51
n_verified = round(vr * N)
52
- return (n_verified + correct_sorted[n_verified:].sum()) / N
+ return (n_verified + correct_sorted[n_verified:].sum().item()) / N
53
54
return [_fa(vr) for vr in vrs]
55
@@ -67,8 +67,8 @@ def _afa(vr):
67
afa_weights = torch.arange(1, N + 1) / n_verified
68
return (
69
(n_verified - 1) / 2
70
- + (afa_weights[:n_verified] * correct_sorted[:n_verified]).sum()
71
- + correct_sorted[n_verified:].sum()
+ + (afa_weights[:n_verified] * correct_sorted[:n_verified]).sum().item()
+ + correct_sorted[n_verified:].sum().item()
72
) / N
73
74
return [_afa(vr) for vr in vrs]
0 commit comments