88 changes: 68 additions & 20 deletions math-rm/prm_evaluate.py
@@ -1,16 +1,17 @@
-from transformers import AutoTokenizer
-from transformers import AutoModelForCausalLM
-from datasets import load_dataset
-from accelerate import Accelerator
-import numpy as np
-import torch
-from tqdm import tqdm
 import argparse
 import json
-import time
 import os
-import sys
 import re
+import sys
+import time
+
+import numpy as np
+import torch
+from accelerate import Accelerator
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -33,21 +34,33 @@ def batch_data(data_list, batch_size=8):
     batch_data.append(data_list[last_start:len(data_list)])
     return batch_data
 
-def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank):
+def truncate_input_ids(input_ids, score_ids, max_len):
+    if input_ids.size(-1) > max_len:
+        input_ids = input_ids[:,:max_len]
+        score_ids = np.array(score_ids)
+        score_ids = score_ids[score_ids < max_len].tolist()
+    return input_ids, score_ids
+
+def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank, max_len=4096):
     prompt = sample['prompt']
     scores_list = []
     #text_list = []
     answers = sample['answers'][:args.num_n]
     step_scores = []
+    all_status = []
     for ans in answers:
-        single_step_score = []
         conversation = []
-        forward_conv = []
         if args.model_type == "Mistral":
             ans_list = ans.split("ки\n")
         else:
             ans_list = ans.split("\n\n")
         ans_list = [j.strip() for j in ans_list]
+        score_ids = []
+        # status codes:
+        # 0: normal
+        # 1: truncated because the input exceeded max_len
+        # 2: no complete step left after truncation (score_ids is empty)
+        status = 0
         for k in range(len(ans_list)):
             if k == 0:
                 text = prompt + " " + ans_list[0]
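
As a quick illustration of the helper's contract (a self-contained sketch with toy values, restating the function from this hunk): positions at or beyond `max_len` no longer index a valid token after the cut, so they are dropped from `score_ids`.

```python
import numpy as np
import torch

def truncate_input_ids(input_ids, score_ids, max_len):
    # same logic as the hunk above: cut the sequence, then keep only the
    # score positions that still index a valid token afterwards
    if input_ids.size(-1) > max_len:
        input_ids = input_ids[:, :max_len]
        score_ids = np.array(score_ids)
        score_ids = score_ids[score_ids < max_len].tolist()
    return input_ids, score_ids

toy_ids = torch.arange(10).unsqueeze(0)              # toy sequence of shape (1, 10)
ids, kept = truncate_input_ids(toy_ids, [3, 7, 9], max_len=8)
print(ids.shape, kept)                               # torch.Size([1, 8]) [3, 7]
```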
@@ -56,18 +69,47 @@ def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank):
             conversation.append({"content":text,"role":"user"})
             conversation.append({"content":"+","role":"assistant"})
 
-            input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt").to(local_rank)
+            # concatenate the tokens of each part of the conversation, recording the position of the reward token
+            if k == 0:
+                input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt")
+            else:
+                input_prompt = tokenizer.apply_chat_template(conversation[-2:], tokenize=False)
+                # search for the user-turn header pattern: <xxxx>user<xxxx>
+                user_start_pattern = r'<[^<>]+>user<[^<>]+>'
+                step_start_idx = re.search(user_start_pattern, input_prompt).start()
+                input_prompt = input_prompt[step_start_idx:]
+                step_ids = tokenizer(input_prompt,return_tensors="pt",add_special_tokens=False)['input_ids']
+                input_ids = torch.cat([input_ids,step_ids],dim=-1)
+
+            # record the position where the reward token is predicted (typically the '\n\n' token)
+            score_ids.append(input_ids.size(-1)-3)
+
+            # check that the concatenated input_ids match a full re-tokenization
+            ref_input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt")
+            assert torch.all(ref_input_ids == input_ids)
+
+        if input_ids.size(-1) > max_len:
+            input_ids, score_ids = truncate_input_ids(input_ids, score_ids, max_len)
+            status = 1
+
+        if len(score_ids) == 0:
+            single_step_score = [0.0]
+            status = 2
+        else:
             with torch.no_grad():
-                logits = model(input_ids).logits[:,-3,candidate_tokens] # simple version: the +/- label is predicted at the '-3' position
+                logits = model(input_ids.to(local_rank)).logits[0,:,candidate_tokens]
+                logits = logits[score_ids,:]
                 scores = logits.softmax(dim=-1)[:,0] # index 0 is the prob of '+' (index 1 is '-')
                 #print(scores)
-                single_step_score.append(scores[0].detach().to('cpu', dtype=torch.float32).item())
+                single_step_score = scores.detach().to('cpu', dtype=torch.float32).tolist()
 
         step_scores.append(single_step_score)
+        all_status.append(status)
         scores_list.append(sum(single_step_score)/len(single_step_score))
 
     idx = scores_list.index(max(scores_list))
     sample['step_scores'] = step_scores
+    sample['status'] = all_status
     return sample['label'][idx] == 1,sample


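The reworked loop builds `input_ids` incrementally and then scores every recorded step position in a single forward pass, instead of re-running the model once per step. A minimal sketch of the gather-and-softmax part, with made-up shapes and hypothetical token ids (`candidate_tokens` stands for the ids of the '+' and '-' labels):

```python
import torch

torch.manual_seed(0)
logits_full = torch.randn(1, 12, 32)                 # toy (batch=1, seq_len=12, vocab=32)
candidate_tokens = [4, 5]                            # hypothetical ids of '+' and '-'
score_ids = [3, 7, 11]                               # one recorded position per solution step

step_logits = logits_full[0, :, candidate_tokens]    # (seq_len, 2)
step_logits = step_logits[score_ids, :]              # (num_steps, 2)
step_scores = step_logits.softmax(dim=-1)[:, 0]      # P('+') at each step position
print(step_scores.tolist())                          # three per-step probabilities
```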
@@ -88,19 +130,21 @@ def worker(args, model, tokenizer, data, local_rank):
 if __name__ == "__main__":
     args = parse_args()
 
-    accelerator = Accelerator()
+    accelerator = Accelerator(mixed_precision='bf16')
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     #print(world_size)
-    ds = load_dataset(args.dataset,split="test").select(range(8))
-    local_rank = Accelerator().local_process_index
+    ds = load_dataset(args.dataset,split="test")
+    local_rank = accelerator.local_process_index
     print("---------------")
     print("begin to load reward model.")
     print("---------------")
     downloaded = False
     while not downloaded:
         try:
             tokenizer = AutoTokenizer.from_pretrained(args.reward_name_or_path)
-            model = AutoModelForCausalLM.from_pretrained(args.reward_name_or_path, torch_dtype=torch.bfloat16).to(local_rank).eval()
+            model = AutoModelForCausalLM.from_pretrained(args.reward_name_or_path)
+            model = accelerator.prepare(model)
+            model.eval()
             downloaded = True
         except Exception as error:
             print("An error occurred:", error)
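
`accelerator.prepare(model)` places the model on the correct per-process device (and wraps it in DDP under multi-GPU), and `mixed_precision='bf16'` lets Accelerate handle the bf16 casting, so the manual `torch_dtype=torch.bfloat16` / `.to(local_rank)` calls are no longer needed. A minimal sketch of that setup path (the checkpoint path is a placeholder):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator(mixed_precision='bf16')
print(accelerator.device)  # the per-process device, e.g. cuda:0

# placeholder checkpoint; any causal-LM reward model would do
model = AutoModelForCausalLM.from_pretrained("path/to/reward-model")
model = accelerator.prepare(model)  # device placement (+ DDP wrapping if multi-GPU)
model.eval()
```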
@@ -109,7 +153,10 @@ def worker(args, model, tokenizer, data, local_rank):
 
     tokenizer.padding_side = "right"
     tokenizer.pad_token = tokenizer.eos_token
-    model.config.pad_token_id = model.config.eos_token_id
+    if accelerator.distributed_type == "MULTI_GPU":
+        model.module.config.pad_token_id = model.module.config.eos_token_id
+    else:
+        model.config.pad_token_id = model.config.eos_token_id
 
     data = []
     data_size = len(ds["prompt"])
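
Under multi-GPU, `prepare` returns a `DistributedDataParallel` wrapper, so the underlying config sits on `model.module`; that is what the branch above accounts for. A branch-free alternative would be a sketch using Accelerate's `unwrap_model` helper (reusing `accelerator` and `model` from the surrounding script):

```python
# unwrap_model returns the underlying model whether or not DDP wrapped it,
# so the same two lines work in single- and multi-GPU runs
unwrapped = accelerator.unwrap_model(model)
unwrapped.config.pad_token_id = unwrapped.config.eos_token_id
```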
@@ -135,6 +182,7 @@ def worker(args, model, tokenizer, data, local_rank):
 
     import torch.distributed as dist
 
+    accelerator.wait_for_everyone()
     dist.all_gather_object(all_process_list, data_to_send)
     gathered_data = []
     gathered_save_data = []
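
`dist.all_gather_object` collects one arbitrary picklable object from every rank into a pre-sized list, and the added `accelerator.wait_for_everyone()` barrier keeps a fast rank from entering the collective before the others finish scoring. A minimal sketch of the pattern (assumes the process group is already initialized, e.g. via `accelerate launch`):

```python
import torch.distributed as dist

if dist.is_initialized():
    world_size = dist.get_world_size()
    all_process_list = [None] * world_size      # one slot per rank
    data_to_send = {"rank": dist.get_rank()}    # any picklable payload
    dist.all_gather_object(all_process_list, data_to_send)
    # every rank now sees every rank's payload, in rank order
    print(all_process_list)
```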