-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_single_sample.py
More file actions
154 lines (123 loc) · 4.67 KB
/
inference_single_sample.py
File metadata and controls
154 lines (123 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from transformers import AutoTokenizer
from src.model.qwen_decoder.modeling import IModelForCausalLM
from src.model.qwen_decoder.configuration import IConfig
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint directory; the config, the weights and the tokenizer are all read from it.
MODEL_PATH = "/data/rech/mofengra/opendecoder/ckpts/Qwen2.5-3B-Instruct_nq_hotpotqa_open_top10_irrel_shuffle/checkpoint-21000"

# ------------------
# Load model/tokenizer
# ------------------
config = IConfig.from_pretrained(MODEL_PATH)
model = IModelForCausalLM.from_pretrained(MODEL_PATH, config=config)
model = model.to(device)
model = model.eval()  # evaluation mode: this script only runs inference

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, padding_side="left")

# Add Passage tokens (must match training): "Passage_1:" ... "Passage_20:".
special_passage_tokens = [f"Passage_{n}:" for n in range(1, 21)]
tokenizer.add_special_tokens({"additional_special_tokens": special_passage_tokens})
# Keep the embedding table in step with the (possibly enlarged) vocabulary.
model.resize_token_embeddings(len(tokenizer))

# Pad-token fallback: try the unknown token first, then fall back to the EOS id
# if the tokenizer still has no pad id after that.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# ------------------
# Example input
# ------------------
question = "Who wrote the novel The Old Man and the Sea?"

# Each passage sits next to its document-level relevance score so the two
# lists derived below always stay aligned (10 entries each).
scored_documents = [
    ("The Old Man and the Sea is a short novel written by Ernest Hemingway in 1951.", 0.95),
    ("Ernest Hemingway was an American novelist and short-story writer.", 0.9),
    ("The book won the Pulitzer Prize for Fiction in 1953.", 0.4),
    ("It tells the story of an aging Cuban fisherman.", 0.2),
    ("Hemingway also wrote For Whom the Bell Tolls.", 0.1),
    ("The novella was published in Life magazine.", 0.1),
    ("It contributed to Hemingway winning the Nobel Prize.", 0.05),
    ("The protagonist is named Santiago.", 0.05),
    ("The story is set in the Gulf Stream.", 0.05),
    ("The work is considered one of Hemingway's classics.", 0.05),
]
documents = [text for text, _ in scored_documents]
# document-level relevance (length = 10)
doc_scores = [score for _, score in scored_documents]

# Normalize exactly like the dataset (normal mode): divide every score by the
# maximum, so the best-scored passage ends up at 1.0.
top_score = max(doc_scores)
norm_scores = [score / top_score for score in doc_scores]
# ------------------
# Build RAG prompt
# ------------------
# One line per document, each prefixed with its 1-based "Passage_k:" marker.
context_parts = [
    f"Passage_{rank}: {text}" for rank, text in enumerate(documents, start=1)
]
context = "\n".join(context_parts)

instruction = (
    "You should answer the question by referring to the knowledge provided below and integrating "
    "the usefulness of your own knowledge. Just directly answer it in several words as a short answer "
    "without any explanation.\n"
)
system_message = {
    "role": "system",
    "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
}
user_message = {
    "role": "user",
    "content": f"{instruction}{context}\n\nQuestion:{question}\n",
}
messages = [system_message, user_message]

# Render the chat as plain text first; tokenization happens in the next step.
# NOTE(review): add_generation_prompt is not passed here, so the rendered prompt
# presumably stops after the user turn -- confirm this matches how the
# checkpoint was trained/evaluated before changing it.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# Pad/truncate to a fixed 4096-token window and return PyTorch tensors.
tokenized = tokenizer(
    prompt,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=4096,
)
input_ids = tokenized["input_ids"].to(device)
attention_mask = tokenized["attention_mask"].to(device)
# Length of the (padded) sequence dimension.
seq_len = input_ids.size(1)
# ------------------
# Build token-level relevance_scores
# ------------------
# One score per token position. Everything starts at 1.0 and only the
# positions inside a passage span are overwritten below.
relevance_scores = torch.ones(seq_len, dtype=torch.float)

# Fail early with a clear message instead of an IndexError in the assignment loop.
if len(norm_scores) < len(documents):
    raise ValueError(
        f"Got {len(norm_scores)} relevance scores for {len(documents)} documents; "
        "every document needs a score."
    )

token_ids = input_ids[0]  # the single sequence in the batch

# find Passage_i token positions
passage_starts = []
for i in range(1, len(documents) + 1):
    tok = f"Passage_{i}:"
    tok_id = tokenizer.convert_tokens_to_ids(tok)
    if tok_id is None:
        raise ValueError(f"{tok!r} is not in the tokenizer vocabulary.")
    matches = (token_ids == tok_id).nonzero(as_tuple=True)[0]
    if matches.numel() == 0:
        # Without this check a missing marker surfaces as a bare IndexError.
        raise ValueError(
            f"Marker token {tok!r} was not found in input_ids; check that it was "
            "added to the tokenizer as a special token and that the prompt was "
            "not truncated."
        )
    # The first occurrence marks where this passage begins.
    passage_starts.append(matches[0].item())

# find assistant start (same logic as dataset): the last "<|im_start|>" that is
# immediately followed by the "assistant" token. If there is none, the label
# region is treated as starting at the end of the sequence.
im_start = tokenizer.convert_tokens_to_ids("<|im_start|>")
assistant = tokenizer.convert_tokens_to_ids("assistant")
label_start = seq_len
positions = (token_ids == im_start).nonzero(as_tuple=True)[0].tolist()
for p in reversed(positions):
    # Bounds check: "<|im_start|>" may be the very last token of the sequence.
    if p + 1 < seq_len and token_ids[p + 1].item() == assistant:
        label_start = p
        break

# compute passage spans: each passage runs from its marker up to (not
# including) the next marker; the last one stops at label_start - 1.
# NOTE(review): the last span therefore also covers the prompt text that
# follows the final passage (the question); kept as-is to mirror the dataset logic.
span_ends = passage_starts[1:] + [label_start - 1]
spans = list(zip(passage_starts, span_ends))

# assign relevance per token
for (s, e), score in zip(spans, norm_scores):
    relevance_scores[s:e] = score

# Add the batch dimension and move next to the model inputs.
relevance_scores = relevance_scores.unsqueeze(0).to(device)
# ------------------
# Generate
# ------------------
# Greedy decoding (do_sample=False) of at most 64 new tokens, without gradients.
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # Keyword is spelled `relevant_scores`, as in the original call --
        # presumably the name the custom model expects; verify against the model code.
        relevant_scores=relevance_scores,
        max_new_tokens=64,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Skip the first input_ids.shape[-1] positions so only the tokens after the
# prompt are decoded.
generated_ids = outputs[0][input_ids.shape[-1]:]
answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Remove chat-role residue that can show up in the decoded text.
for residue in ("assistant", "<|im_start|>\n", "system\n"):
    answer = answer.replace(residue, "")
# Strip again: removing a leading role header (e.g. "assistant\n...") would
# otherwise leave a stray newline at the start of the printed answer.
answer = answer.strip()
print("Answer:", answer)