-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplugin.py
More file actions
92 lines (76 loc) · 3.59 KB
/
plugin.py
File metadata and controls
92 lines (76 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import asyncio
import re
from typing import List
from swift.plugin import ORM, orms
class CosineReward(ORM):
    """Length-aware reward built on top of the multiple-choice accuracy reward.

    Each completion is first scored by ``CLSAccuracyORM_choice``.  A score of
    ``-1.0`` is passed through unchanged.  For scores of ``1.0`` (correct) and
    ``0.0`` (wrong) the reward is interpolated along a half cosine between a
    "short completion" value and a "long completion" value, based on how many
    tokens the completion has beyond ``soft_cache_length``.
    """

    def __init__(self,
                 tokenizer=None,
                 cosine_min_len_value_wrong: float = -0.5,
                 cosine_max_len_value_wrong: float = 0.0,
                 cosine_min_len_value_correct: float = 1.0,
                 cosine_max_len_value_correct: float = 0.5,
                 cosine_max_len: int = 512,
                 soft_cache_length: int = 256):
        """Store the reward schedule.

        Args:
            tokenizer: Object with an ``encode(text)`` method whose result
                supports ``len()``; used to measure completion length.
            cosine_min_len_value_wrong: Reward for a wrong answer at the
                shortest length.
            cosine_max_len_value_wrong: Reward for a wrong answer at
                ``cosine_max_len`` tokens or more.
            cosine_min_len_value_correct: Reward for a correct answer at the
                shortest length.
            cosine_max_len_value_correct: Reward for a correct answer at
                ``cosine_max_len`` tokens or more.
            cosine_max_len: Token count at which the "max length" value is
                reached.
            soft_cache_length: Number of leading tokens that are exempt from
                length shaping.
        """
        self.tokenizer = tokenizer
        self.min_len_value_wrong = cosine_min_len_value_wrong
        self.max_len_value_wrong = cosine_max_len_value_wrong
        self.min_len_value_correct = cosine_min_len_value_correct
        self.max_len_value_correct = cosine_max_len_value_correct
        self.max_len = cosine_max_len
        self.soft_cache_length = soft_cache_length
        self.accuracy_orm = CLSAccuracyORM_choice()

    @staticmethod
    def cosfn(t, T, min_value, max_value):
        """Half-cosine interpolation between ``max_value`` and ``min_value``.

        Returns ``max_value`` at ``t == 0`` and ``min_value`` at ``t == T``.
        ``T`` must be non-zero; ``t`` is not clamped here.
        """
        import math
        return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Compute one length-shaped reward per completion.

        Args:
            completions: Generated texts.
            solution: Ground-truth answers, aligned with ``completions``.
            **kwargs: Forwarded to the accuracy reward.

        Returns:
            A list with one reward per completion.

        Raises:
            ValueError: If the accuracy reward is not ``1.0``, ``0.0`` or
                ``-1.0``.
        """
        acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
        # Number of tokens over which the reward moves from the short-length
        # value to the max-length value; identical for every completion.
        window = max(0, self.max_len - self.soft_cache_length)

        rewards = []
        for content, acc_reward in zip(completions, acc_rewards):
            if acc_reward == -1.0:
                rewards.append(-1.0)
                continue

            if acc_reward == 1.0:
                short_value = self.min_len_value_correct
                long_value = self.max_len_value_correct
            elif acc_reward == 0.0:
                short_value = self.min_len_value_wrong
                long_value = self.max_len_value_wrong
            else:
                raise ValueError("Invalid acc reward")

            gen_len = len(self.tokenizer.encode(content))
            # Tokens within the soft cache length are not length-shaped.
            excess = max(0, gen_len - self.soft_cache_length)

            if window > 0:
                # Clamp to the window: past it the cosine would swing back up
                # and reward over-long completions again.
                reward = self.cosfn(min(excess, window), window, long_value, short_value)
            else:
                # cosine_max_len <= soft_cache_length leaves no window to
                # interpolate over (and would divide by zero in cosfn), so
                # fall back to a step at the soft cache length.
                reward = short_value if excess == 0 else long_value
            rewards.append(reward)
        return rewards
class CLSAccuracyORM_choice(ORM):
    """Accuracy reward for multiple-choice answers restricted to a/b/c/d."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Score every completion against its ground-truth choice.

        The answer is taken from the first ``<answer>...</answer>`` span when
        one exists, otherwise from the whole completion.  Both the answer and
        the ground truth are stripped, have spaces, underscores, periods and
        newlines removed, and are lower-cased before comparison.

        Args:
            completions (list[str]): Generated outputs.
            solution (list[str]): Ground truths.

        Returns:
            list[float]: ``-1.0`` when the normalised answer is not one of
            ``a``/``b``/``c``/``d``, ``1.0`` when it equals the normalised
            ground truth, ``0.0`` otherwise.
        """
        removed_chars = str.maketrans('', '', ' _.\n')
        valid_choices = ('a', 'b', 'c', 'd')

        def canonical(text):
            # Strip first, then delete the ignored characters, then lower-case.
            return text.strip().translate(removed_chars).lower()

        scores = []
        for completion, truth in zip(completions, solution):
            tagged = re.search(r'<answer>(.*?)</answer>', completion, re.DOTALL | re.IGNORECASE)
            raw_answer = tagged.group(1) if tagged else completion
            expected = canonical(truth)
            predicted = canonical(raw_answer)

            # An answer outside the allowed choices is penalised even when it
            # happens to equal the ground truth.
            if predicted not in valid_choices:
                score = -1.0
            elif predicted == expected:
                score = 1.0
            else:
                score = 0.0
            scores.append(score)
        return scores
# Register both reward classes in the `orms` object imported from swift.plugin
# under the keys 'external_cls_acc_choice' and 'external_cosine'.
# NOTE(review): presumably these keys are how the training configuration
# selects a reward function by name -- confirm against the swift docs.
orms['external_cls_acc_choice'] = CLSAccuracyORM_choice
orms['external_cosine'] = CosineReward