-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathndcg.py
More file actions
29 lines (26 loc) · 1.09 KB
/
ndcg.py
File metadata and controls
29 lines (26 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import numpy as np
def ndcg_score(y_true, y_score, k=10, gains="exponential"):
    """Return nDCG@k of a ranked list of relevance labels.

    Args:
        y_true: 1-D tensor (or sequence) of relevance labels for the items
            that could be ranked. It is sorted in descending order here to
            form the ideal ranking, so callers may pass it in any order.
        y_score: 1-D tensor (or sequence) of relevance labels listed in the
            order the system ranked the items.
        k: Rank cutoff. Only the first ``k`` positions contribute. Lists
            shorter than ``k`` are evaluated over their full length.
        gains: ``"exponential"`` (gain ``2**rel - 1``) or ``"linear"``
            (gain ``rel``).

    Returns:
        The ratio DCG@k(y_score) / DCG@k(ideal) as a Python float. Returns
        0.0 when the ideal DCG is zero (no relevant items), instead of nan.

    Raises:
        ValueError: If ``gains`` is not one of the supported options.
    """
    if gains not in ("exponential", "linear"):
        raise ValueError("Invalid gains option.")

    def dcg_score(relevances):
        """Return discounted cumulative gain of the first ``k`` labels."""
        rel_k = torch.as_tensor(relevances)[:k].float()
        if gains == "exponential":
            gain = torch.pow(2.0, rel_k) - 1.0
        else:
            # Linear gain must use the truncated labels too, otherwise its
            # length would not match the discounts below.
            gain = rel_k
        # Discount 1/log2(rank + 1) for 1-based ranks. The length follows the
        # truncated list, so inputs shorter than k do not cause shape errors.
        positions = torch.arange(rel_k.numel(), dtype=torch.float32,
                                 device=rel_k.device)
        discounts = torch.log2(positions + 2.0)
        return torch.sum(gain / discounts)

    # The ideal DCG is the DCG of the best possible ordering of the labels.
    ideal = torch.sort(torch.as_tensor(y_true), descending=True).values
    best = dcg_score(ideal)
    actual = dcg_score(y_score)
    if best.item() == 0:
        # No relevant items: nDCG is undefined (0/0). Report 0.0 rather than nan.
        return 0.0
    return (actual / best).item()
if __name__ == '__main__':
    # Demo: relevance labels of a system ranking and of the ideal ranking of
    # the same items. Print nDCG@k for every cutoff the ranking is long
    # enough to cover.
    system_labels = torch.from_numpy(np.asarray([1, 1, 0, 1, 0, 1, 0, 0]))
    ideal_labels = torch.from_numpy(np.asarray([1, 1, 1, 1, 0, 0, 0, 0]))
    cutoffs = (1, 3, 4, 8, 10)
    for cutoff in cutoffs:
        # Skip cutoffs deeper than the ranking itself.
        if cutoff > len(system_labels):
            continue
        print(ndcg_score(ideal_labels, system_labels, k=cutoff))