diff --git a/math-rm/prm_evaluate.py b/math-rm/prm_evaluate.py
index 5b84c6b..c159020 100644
--- a/math-rm/prm_evaluate.py
+++ b/math-rm/prm_evaluate.py
@@ -1,16 +1,17 @@
-from transformers import AutoTokenizer
-from transformers import AutoModelForCausalLM
-from datasets import load_dataset
-from accelerate import Accelerator
-import numpy as np
-import torch
-from tqdm import tqdm
 import argparse
 import json
-import time
 import os
-import sys
 import re
+import sys
+import time
+
+import numpy as np
+import torch
+from accelerate import Accelerator
+from datasets import load_dataset
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -33,21 +34,33 @@ def batch_data(data_list, batch_size=8):
     batch_data.append(data_list[last_start:len(data_list)])
     return batch_data
 
-def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank):
+def truncate_input_ids(input_ids, score_ids, max_len):
+    if input_ids.size(-1) > max_len:
+        input_ids = input_ids[:,:max_len]
+        score_ids = np.array(score_ids)
+        score_ids = score_ids[score_ids < max_len].tolist()  # keep only positions that survive truncation
+    return input_ids, score_ids
+
+def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank, max_len=4096):
     prompt = sample['prompt']
     scores_list = []
-    #text_list = []
     answers = sample['answers'][:args.num_n]
     step_scores = []
+    all_status = []
     for ans in answers:
         single_step_score = []
         conversation = []
-        forward_conv = []
         if args.model_type == "Mistral":
             ans_list = ans.split("ки\n")
         else:
             ans_list = ans.split("\n\n")
         ans_list = [j.strip() for j in ans_list]
+        score_ids = []
+        # status codes:
+        # 0: normal
+        # 1: truncated because the input exceeded max_len
+        # 2: no step position survived truncation (score_ids is empty)
+        status = 0
         for k in range(len(ans_list)):
             if k == 0:
                 text = prompt + " " + ans_list[0]
@@ -56,18 +69,47 @@ def select_sample(args,sample,model,tokenizer,candidate_tokens,local_rank):
             conversation.append({"content":text,"role":"user"})
             conversation.append({"content":"+","role":"assistant"})
 
-            input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt").to(local_rank)
+            # incrementally concatenate the tokens of each conversation turn, recording the position of the reward token
+            if k == 0:
+                input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt")
+            else:
+                input_prompt = tokenizer.apply_chat_template(conversation[-2:], tokenize=False)
+                # find where the newly appended user turn starts in the rendered template
+                user_start_pattern = r'<[^<>]+>user<[^<>]+>'
+                step_start_idx = re.search(user_start_pattern, input_prompt).start()
+                input_prompt = input_prompt[step_start_idx:]
+                step_ids = tokenizer(input_prompt,return_tensors="pt",add_special_tokens=False)['input_ids']
+                input_ids = torch.cat([input_ids,step_ids],dim=-1)
+
+            # record the position that predicts the reward token for this step (typically the '\n\n' separator)
+            score_ids.append(input_ids.size(-1)-3)
+
+        # sanity check: the incrementally built input_ids must match a full re-tokenization
+        ref_input_ids = tokenizer.apply_chat_template(conversation,return_tensors="pt")
+        assert torch.all(ref_input_ids == input_ids)
+
+        if input_ids.size(-1) > max_len:
+            input_ids, score_ids = truncate_input_ids(input_ids, score_ids, max_len)
+            status = 1
+
+        if len(score_ids) == 0:
+            single_step_score = [0.0]
+            status = 2
+        else:
             with torch.no_grad():
-                logits = model(input_ids).logits[:,-3,candidate_tokens] #simple version, the +/- is predicted by the '-3' position
+                logits = model(input_ids.to(local_rank)).logits[0,:,candidate_tokens]
+                logits = logits[score_ids,:]
                 scores = logits.softmax(dim=-1)[:,0] # 0 means the prob of + (1 mean -)
                 #print(scores)
-                single_step_score.append(scores[0].detach().to('cpu', dtype=torch.float32).item())
+                single_step_score = scores.detach().to('cpu', dtype=torch.float32).tolist()
 
         step_scores.append(single_step_score)
+        all_status.append(status)
         scores_list.append(sum(single_step_score)/len(single_step_score))
 
     idx = scores_list.index(max(scores_list))
     sample['step_scores'] = step_scores
+    sample['status'] = all_status
     return sample['label'][idx] == 1,sample
 
 
@@ -88,11 +130,11 @@ def worker(args, model, tokenizer, data, local_rank):
 
 if __name__ == "__main__":
     args = parse_args()
-    accelerator = Accelerator()
+    accelerator = Accelerator(mixed_precision='bf16')
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     #print(world_size)
-    ds = load_dataset(args.dataset,split="test").select(range(8))
-    local_rank = Accelerator().local_process_index
+    ds = load_dataset(args.dataset,split="test")
+    local_rank = accelerator.local_process_index
     print("---------------")
     print("begin to load reward model.")
     print("---------------")
@@ -100,7 +142,9 @@ def worker(args, model, tokenizer, data, local_rank):
     while not downloaded:
         try:
             tokenizer = AutoTokenizer.from_pretrained(args.reward_name_or_path)
-            model = AutoModelForCausalLM.from_pretrained(args.reward_name_or_path, torch_dtype=torch.bfloat16).to(local_rank).eval()
+            model = AutoModelForCausalLM.from_pretrained(args.reward_name_or_path)
+            model = accelerator.prepare(model)
+            model.eval()
             downloaded = True
         except Exception as error:
             print("An error occurred:", error)
@@ -109,7 +153,10 @@ def worker(args, model, tokenizer, data, local_rank):
 
     tokenizer.padding_side = "right"
     tokenizer.pad_token = tokenizer.eos_token
-    model.config.pad_token_id = model.config.eos_token_id
+    if accelerator.distributed_type == "MULTI_GPU":
+        model.module.config.pad_token_id = model.module.config.eos_token_id
+    else:
+        model.config.pad_token_id = model.config.eos_token_id
 
     data = []
     data_size = len(ds["prompt"])
@@ -135,6 +182,7 @@ def worker(args, model, tokenizer, data, local_rank):
 
     import torch.distributed as dist
 
+    accelerator.wait_for_everyone()
     dist.all_gather_object(all_process_list, data_to_send)
     gathered_data = []
     gathered_save_data = []
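Note for reviewers: the core change in select_sample replaces one forward pass per step (reading the +/- logits at the fixed '-3' position) with a single forward pass per answer whose logits are gathered at every recorded score_ids position. Below is a minimal, self-contained sketch of that gather; the sizes, the random logits, and the candidate token ids are hypothetical stand-ins, not the actual reward model or tokenizer.

import torch

# Hypothetical stand-ins: random logits instead of a real model forward.
seq_len, vocab_size = 20, 32                  # toy sizes; the real vocab is much larger
candidate_tokens = [5, 7]                     # hypothetical ids of the '+' and '-' tokens
logits = torch.randn(1, seq_len, vocab_size)  # stands in for model(input_ids).logits

# Step-boundary positions recorded during concatenation (input_ids.size(-1) - 3 per step).
score_ids = [6, 12, 17]

plus_minus = logits[0, :, candidate_tokens]    # (seq_len, 2): +/- logits at every position
plus_minus = plus_minus[score_ids, :]          # (num_steps, 2): keep step boundaries only
step_probs = plus_minus.softmax(dim=-1)[:, 0]  # column 0 = P('+') for each step
print(step_probs.tolist())                     # one score per reasoning step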