-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
98 lines (79 loc) · 2.79 KB
/
utils.py
File metadata and controls
98 lines (79 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from consts import NONE, PAD
import torch.nn.functional as F
# from : https://zhuanlan.zhihu.com/p/418305402
def compute_kl_loss(self, p, q ,pad_mask = None):
    """Bidirectional (symmetric) KL-divergence loss between two logit tensors.

    R-Drop style regularizer: both KL(p||q) and KL(q||p) are computed
    element-wise and their summed values averaged.

    :param p: logits tensor
    :param q: logits tensor, same shape as p
    :param pad_mask: optional boolean mask; True positions are zeroed before
                     the reduction (for seq-level tasks) -- assumed
                     broadcastable to the element-wise KL tensor
    :return: scalar tensor, (sum KL(p||q) + sum KL(q||p)) / 2
    """
    kl_pq = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
    kl_qp = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
    # Drop padded positions before summing (seq-level tasks only).
    if pad_mask is not None:
        kl_pq = kl_pq.masked_fill(pad_mask, 0.)
        kl_qp = kl_qp.masked_fill(pad_mask, 0.)
    # Sum over all elements, then average the two KL directions.
    return (kl_pq.sum() + kl_qp.sum()) / 2
def build_vocab(labels, BIO_tagging=True):
    """Build the tag vocabulary and its two index mappings.

    :param labels: iterable of base label names
    :param BIO_tagging: when True, each label expands to 'B-<label>' and
                        'I-<label>'; otherwise labels are kept verbatim
    :return: (all_labels, label2idx, idx2label); PAD and NONE always occupy
             the first two indices
    """
    all_labels = [PAD, NONE]
    if BIO_tagging:
        for label in labels:
            all_labels.extend(('B-{}'.format(label), 'I-{}'.format(label)))
    else:
        all_labels.extend(labels)
    label2idx = {}
    idx2label = {}
    for idx, tag in enumerate(all_labels):
        label2idx[tag] = idx
        idx2label[idx] = tag
    return all_labels, label2idx, idx2label
def calc_metric(y_true, y_pred):
    """Compute precision / recall / F1 over tuple-valued items.

    :param y_true: [(tuple), ...] gold items
    :param y_pred: [(tuple), ...] predicted items
    :return: (precision, recall, f1)
    """
    num_proposed = len(y_pred)
    num_gold = len(y_true)
    gold = set(y_true)
    num_correct = sum(1 for item in y_pred if item in gold)
    print('proposed: {}\tcorrect: {}\tgold: {}'.format(num_proposed, num_correct, num_gold))
    # Degenerate cases: an empty prediction/gold set scores as perfect.
    precision = num_correct / num_proposed if num_proposed != 0 else 1.0
    recall = num_correct / num_gold if num_gold != 0 else 1.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
    return precision, recall, f1
def find_triggers(labels):
    """Extract (start, end, type) trigger spans from a BIO tag sequence.

    :param labels: ['B-Conflict:Attack', 'I-Conflict:Attack', 'O', 'B-Life:Marry']
    :return: [(0, 2, 'Conflict:Attack'), (3, 4, 'Life:Marry')]

    `end` is exclusive.  An 'I-' tag with no preceding 'B-' is ignored,
    and (as in the original) the 'I-' type is not required to match the
    'B-' type when extending a span.
    """
    # Split only on the FIRST '-' so event types that themselves contain a
    # hyphen (e.g. 'B-Life-Die') keep their full name; an unbounded split
    # would truncate 'Life-Die' to 'Life'.
    labels = [label.split('-', 1) for label in labels]
    result = []
    for i, parts in enumerate(labels):
        if parts[0] == 'B':
            result.append([i, i + 1, parts[1]])
    # Extend each span over the run of 'I' tags that follows it.
    for item in result:
        j = item[1]
        while j < len(labels) and labels[j][0] == 'I':
            j += 1
        item[1] = j
    return [tuple(item) for item in result]
# To watch performance comfortably on a telegram when training for a long time
def report_to_telegram(text, bot_token, chat_id):
    """Best-effort progress report to a Telegram chat while training.

    :param text: message body (may contain any characters)
    :param bot_token: Telegram bot API token
    :param chat_id: destination chat id

    Any failure (missing `requests`, network error, bad token) is printed
    and swallowed so a long training run is never interrupted.
    """
    try:
        import requests
        # Pass the message via `params` so requests URL-encodes it; raw
        # interpolation into the query string broke messages containing
        # '&', '#', '+' or similar.  The timeout keeps a hung request from
        # blocking the training loop indefinitely.
        requests.get(
            'https://api.telegram.org/bot{}/sendMessage'.format(bot_token),
            params={'chat_id': chat_id, 'text': text},
            timeout=10,
        )
    except Exception as e:
        print(e)