Skip to content

Commit de8ec78

Browse files
author
Bartosz Smoczynski
committed
Dump metrics as number not tensor
1 parent 1b0e853 commit de8ec78

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

timm/utils/metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def final_accuracy(self, vrs):
4949

5050
def _fa(vr):
5151
n_verified = round(vr * N)
52-
return (n_verified + correct_sorted[n_verified:].sum()) / N
52+
return (n_verified + correct_sorted[n_verified:].sum().item()) / N
5353

5454
return [_fa(vr) for vr in vrs]
5555

@@ -67,8 +67,8 @@ def _afa(vr):
6767
afa_weights = torch.arange(1, N + 1) / n_verified
6868
return (
6969
(n_verified - 1) / 2
70-
+ (afa_weights[:n_verified] * correct_sorted[:n_verified]).sum()
71-
+ correct_sorted[n_verified:].sum()
70+
+ (afa_weights[:n_verified] * correct_sorted[:n_verified]).sum().item()
71+
+ correct_sorted[n_verified:].sum().item()
7272
) / N
7373

7474
return [_afa(vr) for vr in vrs]

0 commit comments

Comments
 (0)