From 8021879107e5a363b87a843906559cce5691f274 Mon Sep 17 00:00:00 2001
From: maxreciprocate <56548574+maxreciprocate@users.noreply.github.com>
Date: Thu, 7 Dec 2023 18:33:16 +0200
Subject: [PATCH 1/2] feat(train_reward_model): force chatml & add stats for
 consistency; formatting has to happen either through the tokenizer's
 `apply_chat_template` or through ahead-of-time formatting in the dataset

---
 train_reward_model.py | 403 +++++++++++++++++++++++++++++-------------
 1 file changed, 284 insertions(+), 119 deletions(-)

diff --git a/train_reward_model.py b/train_reward_model.py
index d2834fa..5af56c3 100644
--- a/train_reward_model.py
+++ b/train_reward_model.py
@@ -1,29 +1,89 @@
-import torch
-import wandb
 import argparse
 import os
-import transformers
+import sys
+from typing import List, Optional, Dict, Any, Union
+
+import matplotlib
 import numpy as np
-from tqdm import tqdm
-from time import time
+import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR
-from huggingface_hub import list_repo_refs
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import wandb
 from accelerate import Accelerator
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
+from huggingface_hub import list_repo_refs
+from matplotlib import pyplot
+from rich.console import Console
+from rich.table import Table
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from tqdm import tqdm
+from transformers import (AutoConfig, AutoModelForCausalLM,
+                          AutoModelForSequenceClassification, AutoTokenizer,
+                          LlamaConfig, LlamaModel, PreTrainedModel)
+
+
+# for any model which doesn't have a ForSequenceClassification wrapper in transformers
+class ClassificationModel(PreTrainedModel):
+    def __init__(self, model_path):
+        super().__init__(AutoConfig.from_pretrained(model_path, trust_remote_code=True))
+        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).model
+        self.score = torch.nn.Linear(self.config.hidden_size, 1, bias=False)
+
+    def forward(self, input_ids, attention_mask, **kwargs):
+        logits = self.score(self.llm(input_ids, attention_mask=attention_mask, **kwargs)[0])
+        sequence_lengths = (torch.eq(input_ids, self.llm.config.pad_token_id).long().argmax(-1) - 1).to(logits.device)
+        pooled_logits = logits[torch.arange(input_ids.shape[0], device=logits.device), sequence_lengths]
+        return pooled_logits[None]
+
+
+# wrapper for UltraRM model
+class LlamaRewardModel(PreTrainedModel):
+    config_class = LlamaConfig
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+        self.regression_head = nn.Linear(self.config.hidden_size, 1, bias=False)
+
+    def forward(  # args are the same as LlamaForCausalLM
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        hidden_states = transformer_outputs[0]
+        rewards = self.regression_head(hidden_states).squeeze(-1)
+
+        ends = attention_mask.cumsum(dim=1).argmax(dim=1).view(-1,1)
+        rewards = torch.gather(rewards, 1, ends)
+
+        return rewards[None]
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--model_path", default="reciprocate/gpt2-tiny", type=str)
+parser.add_argument("--model_path", default="reciprocate/tiny-mistral", type=str)
 parser.add_argument("--revision", default=None, type=str)
 parser.add_argument("--tokenizer_path", default=None, type=str)
 parser.add_argument("--dataset", default="reciprocate/number-pairs", type=str)
-parser.add_argument("--lr", default=6e-4, type=float)
+parser.add_argument("--lr", default=6e-3, type=float)
 parser.add_argument("--min_lr", default=None, type=float)
-parser.add_argument("--weight_decay", default=0.1, type=float)
-parser.add_argument("--batch_size", default=20, type=int)
-parser.add_argument("--epochs", default=1, type=int)
-parser.add_argument("--seq_length", default=1024, type=int)
+parser.add_argument("--weight_decay", default=0.0, type=float)
+parser.add_argument("--batch_size", default=8, type=int)
+parser.add_argument("--epochs", default=4, type=int)
+parser.add_argument("--seq_length", default=2048, type=int)
 parser.add_argument("--num_unfrozen_layers", default=None, type=int)
 parser.add_argument("--gradient_checkpointing", action="store_true")
 parser.add_argument("--load_in_4bit", action="store_true")
@@ -31,10 +91,11 @@
 parser.add_argument("--checkpoint_dir", default="checkpoints", type=str)
 parser.add_argument("--eval_interval", default=100, type=int)
 parser.add_argument("--only_eval", action="store_true")
-parser.add_argument("--add_oasst_tokens", action="store_true")
+parser.add_argument("--wrapper", action="store_true")
+parser.add_argument("--wrapper_ultra", action="store_true")
+parser.add_argument("--format", default="alpaca", type=str)
 parser.add_argument("--calibration_datasets", default=[], nargs="+", type=str)
-args = parser.parse_args()
-
+args = parser.parse_args(args=[] if "__file__" not in globals() else sys.argv[1:])
 
 def plot_calibration(model_name: str, dataset_name: str, delta_scores: np.ndarray) -> str:
     space = np.linspace(0, 4, 32)
@@ -51,9 +112,6 @@ def plot_calibration(model_name: str, dataset_name: str, delta_scores: np.ndarra
 
         probs.append(prob)
 
-    import matplotlib
-    from matplotlib import pyplot
-
     textcolor = "#333"
     matplotlib.style.use("ggplot")
     matplotlib.rcParams.update({
@@ -99,100 +157,167 @@ def plot_calibration(model_name: str, dataset_name: str, delta_scores: np.ndarra
         revision = "local"
     else:
         revision = list_repo_refs(args.model_path).branches[0].target_commit[:8]
-    model_name = f"{args.model_path}:{revision}"
+    model_name = f"{args.model_path.split('/')[-1]}"
+    experiment = os.environ.get("EXPERIMENT")
 
     accelerator = Accelerator(log_with="wandb")
     accelerator.init_trackers(
         project_name="autocrit",
         config=vars(args),
-        init_kwargs={"wandb": {"name": f"{model_name}@{args.dataset}"}},
+        init_kwargs={"wandb": {"name": experiment if experiment else f"{model_name}@{args.dataset}"}}
    )
 
     tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path or args.model_path)
-
-    if args.add_oasst_tokens:
-        tokenizer.add_tokens(["<|assistant|>", "<|prefix_begin|>", "<|prefix_end|>", "<|prompter|>", "<|system|>"])
-    tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
     tokenizer.padding_side = "right"
     tokenizer.truncation_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
 
-    def tokenize(prompt, selected, rejected, tokenizer):
+    def format(sample: Dict[str, Union[list, str]], tokenizer: AutoTokenizer) -> Dict[str, str]:
+        """
+        There are multiple formats the datasets can be in; this function converts them
+        into a single string pair of (selected, rejected):
+
+        1. sample = {selected: str, rejected: str}
+           assume the samples are already formatted
+
+        2. sample = {prompt: str, selected: str, rejected: str}
+           prepend the prompt to the selected and rejected and format them as a single-turn chatml
+
+        3. sample = {selected: chatml, rejected: chatml}
+           format the selected and rejected according to the tokenizer's template
+        """
+
+        if isinstance(sample["selected"], str) and isinstance(sample["rejected"], str):
+            if "prompt" not in sample:
+                return {
+                    "selected": sample["selected"],
+                    "rejected": sample["rejected"],
+                }
+
+            selected = [{"role": "assistant", "content": sample["selected"]}]
+            rejected = [{"role": "assistant", "content": sample["rejected"]}]
+
+            if isinstance(sample["prompt"], str):
+                prompt = [{"role": "user", "content": sample["prompt"]}]
+
+                selected = prompt + selected
+                rejected = prompt + rejected
+
+        else:
+            selected = sample["selected"]
+            rejected = sample["rejected"]
+
+        return {
+            "selected": tokenizer.apply_chat_template(selected, tokenize=False),
+            "rejected": tokenizer.apply_chat_template(rejected, tokenize=False),
+        }
+
+    def tokenize(x: Dict[str, str], tokenizer: AutoTokenizer) -> Dict[str, torch.LongTensor]:
         return {
-            "selected_input_ids": tokenizer(prompt + selected + tokenizer.eos_token, truncation=True, max_length=args.seq_length).input_ids,
-            "rejected_input_ids": tokenizer(prompt + rejected + tokenizer.eos_token, truncation=True, max_length=args.seq_length).input_ids,
+            "selected_input_ids": tokenizer(x["selected"], truncation=True, max_length=args.seq_length).input_ids,
+            "rejected_input_ids": tokenizer(x["rejected"], truncation=True, max_length=args.seq_length).input_ids,
         }
 
     def collate_fn(batch):
         input_ids = sum([[x["rejected_input_ids"], x["selected_input_ids"]] for x in batch], [])
         return tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
 
-    dataset = load_dataset(args.dataset)
-    if "chosen" in dataset["train"].column_names:
-        dataset = dataset.rename_column("chosen", "selected")
-    if "replies" in dataset["train"].column_names:
-        dataset = dataset.map(lambda x: {"selected": x["replies"][0], "rejected": x["replies"][1]}, remove_columns=["replies"])
-    accelerator.print(args.dataset, dataset)
-
-    def to_vicuna_format(sample):
-        prompt = sample["prompt"].strip()
-        prompt = prompt.replace("\n\nHuman: ", "USER: ") \
-                       .replace("\n\nAssistant: ", " ASSISTANT: ") \
-                       .replace("\n\nAssistant:", " ASSISTANT:")
-        if prompt.startswith("Human: "):
-            prompt = prompt.replace("Human: ", "USER: ")
-        if prompt.startswith(""):
-            prompt = prompt[4:]
-
-        selected = " " + sample["selected"].strip()
-        rejected = " " + sample["rejected"].strip()
-
-        return {"prompt": prompt, "selected": selected, "rejected": rejected}
-
-    def to_oa_format(sample):
-        prompt = sample["prompt"].strip()
-        prompt = prompt.replace("\n\nHuman: ", "<|prompter|>") \
-                       .replace("\n\nAssistant: ", "<|assistant|>") \
-                       .replace("\n\nAssistant:", "<|assistant|>")
-        if prompt.startswith("Human: "):
-            prompt = prompt.replace("Human: ", "<|prompter|>")
-
-        selected = sample["selected"].strip()
-        rejected = sample["rejected"].strip()
-
-        return {"prompt": prompt, "selected": selected, "rejected": rejected}
-
-    if args.add_oasst_tokens:
-        dataset = dataset.map(to_oa_format)
+    dataset_name, split = args.dataset.split(":") if ":" in args.dataset else (args.dataset, None)
+    if os.path.exists(dataset_name):
+        dataset = load_from_disk(dataset_name)
     else:
-        dataset = dataset.map(to_vicuna_format)
+        dataset = load_dataset(dataset_name)
 
-    eval_dataloaders = []
-    for name in args.calibration_datasets:
-        calibration_dataset = load_dataset(name)
-        if "test" in calibration_dataset:
-            calibration_dataset = calibration_dataset["test"]
+    if "test" in dataset:
+        args.calibration_datasets.append(f'{args.dataset}:test')
+        ref_dataset_name = f'{args.dataset}:test'
+    else:
+        if len(args.calibration_datasets):
+            ref_dataset_name = args.calibration_datasets[0]
+        else:
+            ref_dataset_name = args.dataset
+
+    dataset_names = [args.dataset, *args.calibration_datasets]
+    dataset_name, *eval_dataset_names = dataset_names
+    datasets = []
+
+    for name in dataset_names:
+        if ':' in name:
+            dataset_path, split = name.split(':')
+        else:
+            dataset_path = name
+            split = None
+
+        if os.path.exists(dataset_path):
+            dataset = load_from_disk(dataset_path)
         else:
-            calibration_dataset = calibration_dataset["train"]
+            dataset = load_dataset(dataset_path)
+
+        if split is not None:
+            dataset = dataset[split]
+        else:
+            if name in eval_dataset_names and "test" in dataset:
+                dataset = dataset["test"]
+            elif "train" in dataset:
+                dataset = dataset["train"]
+            else:
+                raise ValueError(f"There is no 'train' or 'test' split in `{name}`")
+
+        if "chosen" in dataset.column_names:
+            dataset = dataset.rename_column("chosen", "selected")
+        if "question" in dataset.column_names:
+            dataset = dataset.rename_column("question", "prompt")
+        if "replies" in dataset.column_names:
+            dataset = dataset.map(lambda x: {"selected": x["replies"][0], "rejected": x["replies"][1]}, remove_columns=["replies"])
+
+        dataset = dataset.map(format, fn_kwargs=dict(tokenizer=tokenizer), desc="Formatting")
 
-        accelerator.print(name, calibration_dataset)
-        tokenized = calibration_dataset.map(tokenize, input_columns=["prompt", "selected", "rejected"], fn_kwargs=dict(tokenizer=tokenizer), desc="Tokenizing")
-        dataloader = torch.utils.data.DataLoader(tokenized, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn)
-        eval_dataloaders.append(dataloader)
+        if accelerator.is_main_process:
+            table = Table(title=f"Dataset: {name}", show_lines=True)
+            table.add_column("selected")
+            table.add_column("rejected")
 
-    tokenized = dataset.map(tokenize, input_columns=["prompt", "selected", "rejected"], fn_kwargs=dict(tokenizer=tokenizer), desc="Tokenizing")
-    dataloader = torch.utils.data.DataLoader(tokenized["train"], shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)
-    eval_dataloaders.append(torch.utils.data.DataLoader(tokenized["test"], shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn))
+            # replace forward slash with the next closest unicode character just for printing
+            # because rich treats it as a formatting character
+            clean_tags = lambda s: s.replace("[/INST]", "[∕INST]")
 
-    if transformers.__version__ >= "4.30.0":
-        kwargs = {"load_in_4bit": args.load_in_4bit}
+            for i in range(3):
+                table.add_row(clean_tags(dataset[i]["selected"]), clean_tags(dataset[i]["rejected"]))
+
+            console = Console()
+            console.print(table)
+
+        datasets.append(dataset)
+
+    tokenized_datasets = []
+    for dataset in datasets:
+        tokenized_datasets.append(dataset.map(tokenize, fn_kwargs=dict(tokenizer=tokenizer), desc="Tokenizing"))
+
+    dataset, *eval_datasets = tokenized_datasets
+
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)
+    eval_dataloaders = []
+    for eval_dataset in eval_datasets:
+        eval_dataloaders.append(torch.utils.data.DataLoader(eval_dataset, shuffle=False, batch_size=args.batch_size, collate_fn=collate_fn))
+
+    if args.wrapper:
+        model = ClassificationModel.from_pretrained(args.model_path)
+        model.llm.resize_token_embeddings(len(tokenizer))
+        model.llm.config.pad_token_id = tokenizer.pad_token_id
+    elif args.wrapper_ultra:
+        model = LlamaRewardModel.from_pretrained(args.model_path, torch_dtype=torch.float16)
+        model.model.resize_token_embeddings(len(tokenizer))
     else:
-        kwargs = {}
-    model = AutoModelForSequenceClassification.from_pretrained(args.model_path, revision=args.revision, num_labels=1, **kwargs)
-    model.config.pad_token_id = tokenizer.pad_token_id
-    model.resize_token_embeddings(len(tokenizer))
+        model = AutoModelForSequenceClassification.from_pretrained(args.model_path, revision=args.revision, num_labels=1, trust_remote_code=True, load_in_4bit=args.load_in_4bit)
+        model.config.pad_token_id = tokenizer.pad_token_id
+        model.resize_token_embeddings(len(tokenizer))
 
     if args.gradient_checkpointing:
-        model.gradient_checkpointing_enable()
+        if isinstance(model, ClassificationModel):
+            model.llm.gradient_checkpointing_enable()
+        else:
+            model.gradient_checkpointing_enable()
 
     if args.downscale_weight:
         model.score.weight.data *= 0.1
@@ -221,7 +346,6 @@ def to_oa_format(sample):
         model, *eval_dataloaders = accelerator.prepare(model, *eval_dataloaders)
     else:
         opt = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), eps=1e-08, weight_decay=args.weight_decay)
-
         scheduler = CosineAnnealingLR(opt, T_max=len(dataloader) * args.epochs, eta_min=args.min_lr or args.lr)
 
         model, opt, scheduler, dataloader, *eval_dataloaders, = accelerator.prepare(model, opt, scheduler, dataloader, *eval_dataloaders)
@@ -232,64 +356,105 @@ def to_oa_format(sample):
     for iepoch in range(args.epochs):
         for batch in dataloader:
             if step % args.eval_interval == 0 or step == tbar.total - 1:
-                for dataset_name, eval_dataloader in zip(args.calibration_datasets + [args.dataset], eval_dataloaders):
+                accuracies = {}
+                statistics = {}
+                for dataset_name, eval_dataloader in zip(eval_dataset_names, eval_dataloaders):
                     model.eval()
                     all_scores, all_delta_scores, all_tokens = [], [], []
+                    main_scores, main_tokens = [], []
 
                     for batch in tqdm(eval_dataloader, desc=f"Evaluating on {dataset_name}", disable=not accelerator.is_main_process, leave=args.only_eval):
                         with torch.no_grad():
                             scores = model(**batch)[0]
+                        delta_scores = scores.reshape(-1, 2).diff().view(-1)
 
-                        delta_scores = scores.reshape(-1, 2).diff().view(-1)
+                        main_scores.extend(scores.view(-1).tolist())
+                        main_tokens.extend(batch["input_ids"].tolist())
+
+                        scores = accelerator.gather_for_metrics(scores.view(-1))
                         delta_scores = accelerator.gather_for_metrics(delta_scores)
+
+                        all_scores.extend(scores.tolist())
                         all_delta_scores.extend(delta_scores.tolist())
-                        all_scores.extend(scores.view(-1).tolist())
-                        all_tokens.extend(batch["input_ids"].tolist())
 
                     delta_scores = np.hstack(all_delta_scores)
+                    scores = np.hstack(all_scores)
                     accuracy = (delta_scores > 0).mean()
+                    accuracies[dataset_name] = accuracy
 
                     if accelerator.is_main_process:
                         image_path = plot_calibration(model_name, dataset_name, delta_scores)
 
-                        texts = [text.replace(tokenizer.pad_token, "") for text in tokenizer.batch_decode(all_tokens)]
-                        samples = wandb.Table(["text", "score"], rows=list(zip(texts, all_scores))[:128])
+                        texts = [text.replace(tokenizer.pad_token, "") for text in tokenizer.batch_decode(main_tokens)]
+                        samples = wandb.Table(["text", "score"], rows=list(zip(texts, main_scores))[:128])
+
+                        postfix = f"@{dataset_name.split('/')[-1]}"
+                        stats = {
+                            f"delta_scores{postfix}": delta_scores,
+                            f"delta_scores/mean{postfix}": delta_scores.mean(),
+                            f"delta_scores/std{postfix}": delta_scores.std(),
+                            f"delta_scores/min{postfix}": delta_scores.min(),
+                            f"delta_scores/max{postfix}": delta_scores.max(),
+                            f"scores{postfix}": wandb.Histogram(scores),
+                            f"scores/mean{postfix}": scores.mean(),
+                            f"scores/std{postfix}": scores.std(),
+                            f"scores/min{postfix}": scores.min(),
+                            f"scores/max{postfix}": scores.max(),
+                        }
+
+                        statistics[dataset_name] = stats
 
-                        postfix = "" if dataset_name == args.dataset else f"@{dataset_name.split('/')[-1]}"
                         accelerator.log({
                             f"accuracy{postfix}": accuracy,
-                            f"samples{postfix}": samples,
-                            f"delta_scores{postfix}": delta_scores,
                             f"calibration{postfix}": wandb.Image(image_path),
+                            f"samples{postfix}": samples,
+                            **stats,
                         }, step=step)
 
-                    if accuracy > best_accuracy and dataset_name == args.dataset:
-                        best_accuracy = accuracy
-                        accelerator.log({"best_accuracy": best_accuracy}, step=step)
-
-                        if args.only_eval:
-                            exit()
-                        else:
-                            path = f"{model_name}_{args.dataset}_{args.lr}".replace("/", "_").replace(":", "_").replace("@", "_")
-                            accelerator.unwrap_model(model).save_pretrained(
-                                os.path.join(args.checkpoint_dir, path),
-                                save_function=accelerator.save,
-                                is_main_process=accelerator.is_main_process,
-                                state_dict=accelerator.get_state_dict(model),
-                            )
-                            if accelerator.is_main_process:
-                                tokenizer.save_pretrained(os.path.join(args.checkpoint_dir, path))
-                            accelerator.print(f"Checkpointing -> {os.path.join(args.checkpoint_dir, path)}")
+                tbar.set_postfix(accuracy=accuracies.get(ref_dataset_name), best_accuracy=best_accuracy)
 
-                    if dataset_name == args.dataset:
-                        tbar.set_postfix(accuracy=accuracy, best_accuracy=best_accuracy)
+                if accuracies.get(ref_dataset_name, 0) > best_accuracy:
+                    best_accuracy = accuracies.get(ref_dataset_name, 0)
+
+                    best_accuracies = {f"accuracy@{k}@best": v for k, v in accuracies.items()}
+                    best_statistics = [{f"{d}/{k}@best": v for k, v in xs.items()} for d, xs in statistics.items()]
+                    best_statistics = {k: v for xs in best_statistics for k, v in xs.items()}
+
+                    accelerator.log({
+                        "best_accuracy": best_accuracy,
+                        **best_accuracies,
+                        **best_statistics,
+                    }, step=step)
+
+                    if not args.only_eval:
+                        path = f"{model_name}_{args.dataset}_{experiment}_{args.lr}".replace("/", "_").replace(":", "_").replace("@", "_")
+                        accelerator.unwrap_model(model).save_pretrained(
+                            os.path.join(args.checkpoint_dir, path),
+                            save_function=accelerator.save,
+                            is_main_process=accelerator.is_main_process,
+                            state_dict=accelerator.get_state_dict(model),
+                        )
+                        if accelerator.is_main_process:
+                            tokenizer.save_pretrained(os.path.join(args.checkpoint_dir, path))
+                        accelerator.print(f"Checkpoint: {os.path.join(args.checkpoint_dir, path)}")
+
+                    tbar.set_postfix(accuracy=accuracies.get(ref_dataset_name), best_accuracy=best_accuracy)
 
                 accelerator.wait_for_everyone()
 
+                if args.only_eval:
+                    exit()
+
             model.train()
             with accelerator.accumulate(model):
-                scores = model(**batch, use_cache=not args.gradient_checkpointing)[0]
-                loss = -F.logsigmoid(scores.reshape(-1, 2).diff()).mean()
+                scores = model(batch["input_ids"], attention_mask=batch["attention_mask"], use_cache=not args.gradient_checkpointing)[0]
+
+                if "scores" in batch:
+                    loss = F.mse_loss(scores.view(-1), batch["scores"].view(-1))
+                else:
+                    delta_scores = scores.reshape(-1, 2).diff()
+                    loss = -F.logsigmoid(delta_scores).mean()
+
                 accelerator.backward(loss)
                 opt.step()
                 opt.zero_grad()

From 9ae61bd25e906c6b10684bf9fffe9431a17d51da Mon Sep 17 00:00:00 2001
From: maxreciprocate <56548574+maxreciprocate@users.noreply.github.com>
Date: Thu, 7 Dec 2023 18:40:00 +0200
Subject: [PATCH 2/2] feat(README): usage snippet with `apply_chat_template`

---
 README.md | 40 ++++++++++++++++++++++++++++------------
 1 file changed, 28 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 0acefce..f2d8410 100644
--- a/README.md
+++ b/README.md
@@ -2,35 +2,51 @@ A repository for transformer critique learning and generation.
 
 ## Scalar reward models
-Train [OpenLLaMA-13B](https://github.com/openlm-research/open_llama) on the [Helpful and Harmless dataset](https://github.com/anthropics/hh-rlhf):
+Train [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) on the [UltraFeedback](https://huggingface.co/datasets/allenai/ultrafeedback_binarized_cleaned) dataset:
 
 ```bash
 accelerate launch --config_file configs/accelerate/zero2.yaml \
     train_reward_model.py \
-    --model_path openlm-research/open_llama_13b \
-    --dataset pvduy/rm_oa_hh \
-    --batch_size 1 \
+    --model_path mistralai/Mistral-7B-Instruct-v0.1 \
+    --dataset allenai/ultrafeedback_binarized_cleaned:train_prefs \
+    --batch_size 4 \
     --eval_interval 1000 \
-    --lr 0.00001 \
+    --lr 0.000003 \
     --weight_decay 0 \
     --num_unfrozen_layers 12 \
    --gradient_checkpointing \
     --checkpoint_dir checkpoints \
-    --calibration_datasets reciprocate/vicuna-fair-eval
+    --calibration_datasets allenai/ultrafeedback_binarized_cleaned:test_prefs Intel/orca_dpo_pairs reciprocate/fair-eval
 ```
 
 Usage:
+
 ```python
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import pipeline
+
+reward_fn = pipeline(
+    "text-classification",
+    model="reciprocate/mistral-7b-rm",
+    truncation=True,
+    max_length=4096,
+    function_to_apply="none"
+)
 
-ckpt = "reciprocate/openllama-13b_rm_oasst-hh"
-model = AutoModelForSequenceClassification.from_pretrained(ckpt, load_in_4bit=True)
-tokenizer = AutoTokenizer.from_pretrained(ckpt)
+chats = [[
+    {"role": "user", "content": "When was the battle at Waterloo?"},
+    {"role": "assistant", "content": "I think it was in 1983, but please double-check that when you have a chance."}
+], [
+    {"role": "user", "content": "When was the battle at Waterloo?"},
+    {"role": "assistant", "content": "The battle at Waterloo took place on June 18, 1815."}
+]]
 
-model(**tokenizer("ASSISTANT: This sentence is a lie.", return_tensors="pt"))[0].item()
+inputs = [reward_fn.tokenizer.apply_chat_template(chat, tokenize=False) for chat in chats]
+output = reward_fn(inputs)
+scores = [x["score"] for x in output]
+scores
 ```
 
 Output:
 ```python
--1.626953125
+>>> [-1.0530743598937988, 0.6916144490242004]
 ```
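
The first commit message allows two formatting paths: rendering chats at training time through the tokenizer's `apply_chat_template`, or formatting the dataset ahead of time. As a reference for the second path, here is a minimal sketch that is not part of the patches themselves: the model checkpoint, dataset name, and output path are placeholders, and it assumes a tokenizer that ships a chat template plus a preference dataset with `question`/`chosen`/`rejected` columns.

```python
# Minimal sketch of ahead-of-time formatting (placeholder names, not from the patch):
# render each preference pair with the tokenizer's chat template and save the result
# to disk so it can be passed to train_reward_model.py via --dataset <local path>.
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
dataset = load_dataset("Intel/orca_dpo_pairs", split="train")

def to_chat_strings(sample):
    # build single-turn chats and render them as plain strings
    prompt = [{"role": "user", "content": sample["question"]}]
    selected = prompt + [{"role": "assistant", "content": sample["chosen"]}]
    rejected = prompt + [{"role": "assistant", "content": sample["rejected"]}]
    return {
        "selected": tokenizer.apply_chat_template(selected, tokenize=False),
        "rejected": tokenizer.apply_chat_template(rejected, tokenize=False),
    }

# keep only the pre-formatted columns so the training script passes them through as-is
dataset = dataset.map(to_chat_strings, remove_columns=dataset.column_names, desc="Formatting ahead of time")
# wrap in a train split so the script's split handling finds it
DatasetDict({"train": dataset}).save_to_disk("preformatted-preference-data")
```

With this layout, the script's `format` step sees plain `selected`/`rejected` strings and no `prompt` column, so it forwards the pre-rendered text unchanged into tokenization.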