diff --git a/examples/tool_env/ppo_gsm8k_code_interpreter.py b/examples/tool_env/ppo_gsm8k_code_interpreter.py
new file mode 100644
index 000000000..dc8309615
--- /dev/null
+++ b/examples/tool_env/ppo_gsm8k_code_interpreter.py
@@ -0,0 +1,189 @@
+# Trains CodeLlama with PPO on GSM8K: the model writes a Python solution that is executed
+# by a code-interpreter tool, and exact matches with the reference answer are rewarded
+import os
+
+from datasets import load_dataset
+from transformers import load_tool
+
+import trlx
+from trlx.data.default_configs import (
+    ModelConfig,
+    OptimizerConfig,
+    PPOConfig,
+    SchedulerConfig,
+    TokenizerConfig,
+    TrainConfig,
+    TRLConfig,
+)
+from trlx.environment.base_tool import ToolEnvironment
+
+os.environ["HF_ALLOW_CODE_EVAL"] = "1"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def ppo_init_config():
+    return TRLConfig(
+        train=TrainConfig(
+            seq_length=768,
+            epochs=100,
+            total_steps=1000,
+            batch_size=32,
+            minibatch_size=1,
+            checkpoint_interval=10000,
+            eval_interval=20,
+            pipeline="PromptPipeline",
+            trainer="AcceleratePPOTrainer",
+            save_best=True,
+            checkpoint_dir="/fsx/home-duyphung/trlx_checkpoints",
+        ),
+        model=ModelConfig(model_path="codellama/CodeLlama-7b-Instruct-hf", num_layers_unfrozen=12),
+        tokenizer=TokenizerConfig(tokenizer_path="codellama/CodeLlama-7b-Instruct-hf", truncation_side="right"),
+        optimizer=OptimizerConfig(
+            name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
+        ),
+        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-7)),
+        method=PPOConfig(
+            name="PPOConfig",
+            num_rollouts=128,
+            chunk_size=16,
+            ppo_epochs=4,
+            init_kl_coef=0.001,
+            target=6,
+            horizon=10000,
+            gamma=1,
+            lam=0.95,
+            num_value_layers_unfrozen=8,
+            cliprange=0.2,
+            cliprange_value=0.2,
+            vf_coef=1,
+            scale_reward="ignored",
+            ref_mean=None,
+            ref_std=None,
+            cliprange_reward=10,
+            gen_kwargs=dict(
+                max_new_tokens=256,
+                top_k=0,
+                top_p=1.0,
+                do_sample=True,
+            ),
+        ),
+    )
+
+
+def exact_match_reward(responses, answers=None):
+    """Reward 1.0 if the tool response exactly matches the reference answer, -1.0 on tool errors."""
+    rewards = []
+    for response, answer in zip(responses, answers):
+        reward = 0.0
+        if "Error:" in response:
+            reward = -1.0
+        else:
+            response = response.strip()
+            answer = answer.strip()
+            reward += float(response == answer)
+        rewards.append(reward)
+    return rewards
+
+
+def create_reward_function(prompt):
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
+    tool_env = ToolEnvironment([load_tool("lvwerra/python-interpreter")], prompt, exact_match_reward, tokenizer)
+
+    def reward_fn(samples, prompts, original_output, **kwargs):
+        rewards = tool_env.get_reward(samples, **{"answers": original_output})
+        return rewards
+
+    return reward_fn
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(ppo_init_config().to_dict(), hparams)
+
+    ds = load_dataset("gsm8k", "main", split="train")
+    ds = ds.rename_columns({"question": "query"})
+    ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
+    ds = ds.select(range(1, len(ds)))  # skip the first sample which is used in the prompt
+
+    ds_test = load_dataset("gsm8k", "main", split="test").select(range(1, 500))
+    ds_test = ds_test.rename_columns({"question": "query"})
+    ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
+
+    df = ds.to_pandas()
+    df_test = ds_test.to_pandas()
+
+    few_shot_prompt = """\
+Instruction: Use a Python API to solve math questions. \
+Write a function solution to solve the following questions, then "print(solution())" to output the result.
+
+Question: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
+
+<request><PythonInterpreter>
+def solution():
+    money_initial = 23
+    bagels = 5
+    bagel_cost = 3
+    money_spent = bagels * bagel_cost
+    money_left = money_initial - money_spent
+    result = money_left
+    return result
+print(solution())
+<call>8<response>
+
+Result = 8 <submit>
+
+Question: Michael loves to paint and sells his creations. He charges $100 for a large painting and $80 for a small painting. \
+At his last art show, he sold 5 large paintings and 8 small paintings. How much did he earn in all?
+
+<request><PythonInterpreter>
+def solution():
+    price_large = 100
+    price_small = 80
+    paintings_large = 5
+    paintings_small = 8
+    total_large_price = price_large * paintings_large
+    total_small_price = price_small * paintings_small
+    total = total_large_price + total_small_price
+    result = total
+    return result
+print(solution())
+<call>1140<response>
+
+Result = 1140 <submit>
+
+"""
+
+    reward_fn = create_reward_function(few_shot_prompt)
+
+    generate_prompt = """\
+{few_shot_prompt}
+Question: {query}
+
+"""
+
+    df["query"] = df["query"].apply(lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x))
+    df_test["query"] = df_test["query"].apply(
+        lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x)
+    )
+
+    train_prompts = [
+        {"prompt": query, "original_output": answer}
+        for (query, answer) in zip(df["query"].tolist(), df["answer"].tolist())
+    ]
+    eval_prompts = [
+        {"prompt": query, "original_output": answer}
+        for (query, answer) in zip(df_test["query"].tolist(), df_test["answer"].tolist())
+    ]
+    trlx.train(
+        reward_fn=reward_fn,
+        prompts=train_prompts,
+        eval_prompts=eval_prompts,
+        config=config,
+        stop_sequences=["<call>", "<submit>"],
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tool_env/ppo_tool_calculate.py b/examples/tool_env/ppo_tool_calculate.py
new file mode 100644
index 000000000..c20b7dae7
--- /dev/null
+++ b/examples/tool_env/ppo_tool_calculate.py
@@ -0,0 +1,187 @@
+# Trains GPT-2 with PPO on randomly generated arithmetic questions, calling a simple
+# calculator tool and rewarding exact matches with the reference answer
+import os
+
+import numpy as np
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from transformers import load_tool
+
+import trlx
+from trlx.data.default_configs import (
+    ModelConfig,
+    OptimizerConfig,
+    PPOConfig,
+    SchedulerConfig,
+    TokenizerConfig,
+    TrainConfig,
+    TRLConfig,
+)
+from trlx.environment.base_tool import ToolEnvironment
+
+os.environ["HF_ALLOW_CODE_EVAL"] = "1"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def ppo_init_config():
+    return TRLConfig(
+        train=TrainConfig(
+            seq_length=256,
+            epochs=100,
+            total_steps=1000,
+            batch_size=8,
+            checkpoint_interval=10000,
+            eval_interval=10,
+            pipeline="PromptPipeline",
+            trainer="AcceleratePPOTrainer",
+            save_best=False,
+        ),
+        model=ModelConfig(model_path="gpt2", num_layers_unfrozen=2),
+        tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"),
+        optimizer=OptimizerConfig(
+            name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=1.0e-5)
+        ),
+        scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-6)),
+        method=PPOConfig(
+            name="PPOConfig",
+            num_rollouts=128,
+            chunk_size=16,
+            ppo_epochs=4,
+            init_kl_coef=0.001,
+            target=6,
+            horizon=10000,
+            gamma=1,
+            lam=0.95,
+            cliprange=0.2,
+            cliprange_value=0.2,
+            vf_coef=1,
+            scale_reward="ignored",
+            ref_mean=None,
+            ref_std=None,
+            cliprange_reward=10,
+            gen_kwargs=dict(
+                max_new_tokens=32,
+                top_k=0,
+                top_p=1.0,
+                do_sample=True,
+            ),
+        ),
+    )
+
+
+def generate_data(n):
+    """Generate random arithmetic tasks and answers."""
+    tasks, answers = [], []
+    for _ in range(n):
+        a = np.random.randint(0, 50)
+        b = np.random.randint(0, 50)
+        op = np.random.choice(["-", "+", "*"])
+        tasks.append(f"\n\nWhat is {a} {op} {b}?")
+        if op == "-":
+            answers.append(a - b)
+        elif op == "+":
+            answers.append(a + b)
+        else:
+            answers.append(a * b)
+    return tasks, answers
+
+
+def exact_match_reward(responses, answers=None):
+    """Reward 1.0 if the tool response numerically matches the reference answer, -1.0 on tool errors."""
+    rewards = []
+    responses = [str(response) for response in responses]
+    answers = [str(answer) for answer in answers]
+    for response, answer in zip(responses, answers):
+        reward = 0.0
+        if "Error:" in response:
+            reward = -1.0
+        else:
+            response = response.strip()
+            answer = answer.strip()
+            try:
+                response = float(response)
+                answer = float(answer)
+                if np.abs(response - answer) < 1e-5:
+                    reward = 1.0
+            except ValueError:
+                reward = 0.0
+        rewards.append(reward)
+    return rewards
+
+
+def create_reward_function(prompt):
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tool_env = ToolEnvironment(
+        {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, prompt, exact_match_reward, tokenizer
+    )
+
+    def reward_fn(samples, prompts, original_output, **kwargs):
+        rewards = tool_env.get_reward(samples, **{"answers": original_output})
+        return rewards
+
+    return reward_fn
+
+
+def main(hparams={}):
+    # Merge sweep config with default config if given
+    config = TRLConfig.update(ppo_init_config().to_dict(), hparams)
+
+    tasks, answers = generate_data(256 * 100)
+    tasks = [x.strip("\n") for x in tasks]
+    df = pd.DataFrame({"query": tasks, "answer": answers})
+
+    df, df_test = train_test_split(df, test_size=500)
+    few_shot_prompt = """\
+Q: What is 13 - 3?
+
+<request><SimpleCalculatorTool>13-3<call>10.0<response>
+
+Result=10<submit>
+
+Q: What is 4 * 3?
+
+<request><SimpleCalculatorTool>4*3<call>12.0<response>
+
+Result=12<submit>
+
+Q: What is 1 + 2?
+
+<request><SimpleCalculatorTool>1+2<call>3<response>
+
+Result=3<submit>"""
+
+    reward_fn = create_reward_function(few_shot_prompt)
+
+    generate_prompt = """\
+{few_shot_prompt}
+
+Q: {query}
+
+"""
+
+    df["query"] = df["query"].apply(lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x))
+    df_test["query"] = df_test["query"].apply(
+        lambda x: generate_prompt.format(few_shot_prompt=few_shot_prompt, query=x)
+    )
+
+    train_prompts = [
+        {"prompt": query, "original_output": answer}
+        for (query, answer) in zip(df["query"].tolist(), df["answer"].tolist())
+    ]
+    eval_prompts = [
+        {"prompt": query, "original_output": answer}
+        for (query, answer) in zip(df_test["query"].tolist(), df_test["answer"].tolist())
+    ]
+    trlx.train(
+        reward_fn=reward_fn,
+        prompts=train_prompts,
+        eval_prompts=eval_prompts,
+        config=config,
+        stop_sequences=["<call>", "<submit>"],
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/trlx/environment/base_tool.py b/trlx/environment/base_tool.py
new file mode 100644
index 000000000..d3e9900c5
--- /dev/null
+++ b/trlx/environment/base_tool.py
@@ -0,0 +1,79 @@
+### Adapted from trl: https://github.com/huggingface/trl/blob/main/trl/environment/base_environment.py
+import re
+
+
+class ToolEnvironment:
+
+    """
+    Parses tool calls from LLM generations, executes the tools, and scores the tool responses with a reward function.
+    """
+
+    def __init__(self, tools=None, prompt=None, reward_fn=None, tokenizer=None):
+        if isinstance(tools, dict):
+            self.tools = tools
+        else:
+            self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
+        self.prompt = prompt
+        self.reward_fn = reward_fn
+        self.request_token = "<request>"
+        self.call_token = "<call>"
+        self.response_token = "<response>"
+        self.submit_token = "<submit>"
+        self.eos_token = tokenizer.eos_token
+
+    def parse_tool_call(self, text):
+        """
+        Parse request string. Expected format: <request><tool_name>query<call>
+        """
+        result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)
+
+        # if we can't find a <request>/<call> span we return none
+        if result is None:
+            return None, None
+        else:
+            extracted_text = result.group()
+
+        result = re.search(r"<(.*?)>", extracted_text)
+
+        # if we can't find a tool name we return none
+        if result is None:
+            return None, None
+        else:
+            tool = result.group(1)
+
+        # split off the tool name
+        query = ">".join(extracted_text.split(">")[1:])
+
+        return tool, query
+
+    def get_reward(self, texts, **kwargs):
+        """
+        Execute the tool call in each generated text and score the tool responses with the reward function
+        """
+        tool_responses = [self.execution(text) for text in texts]
+        reward = self.reward_fn(tool_responses, **kwargs)
+        return reward
+
+    def execution(self, text):
+        """
+        Parse the tool call in the generated text, execute the tool, and return its response
+        """
+        generated_text = self._get_generated_text(text)
+        tool, query = self.parse_tool_call(generated_text)
+        if tool is None or query is None:
+            response = f"Unknown tool call: {query}"
+        elif tool not in self.tools:
+            response = f"Unknown tool {tool}."
+        else:
+            try:
+                response = self.tools[tool](query)
+            except Exception as error:
+                response = f"Tool Error: {str(error)}"
+        return response
+
+    def _get_generated_text(self, text):
+        text = text.strip()
+        text = text.replace(self.eos_token, "")
+        if not text.endswith(self.call_token):
+            text = f"{text}{self.call_token}"
+        return text[len(self.prompt) :]
diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 51d54cf36..b0cad0ff4 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -1082,13 +1082,14 @@ def __init__(
         base_model: transformers.PreTrainedModel,
         *,
         num_layers_unfrozen: int,
+        frozen: bool = True,
     ):
         """
         Args:
             base_model (transformers.PreTrainedModel): The pretrained model to extract upper trunk from
             num_layers_unfrozen (int): The number of trainable layers
         """
-        super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen)
+        super().__init__(base_model, num_layers_unfrozen=num_layers_unfrozen, frozen=frozen)
         self.config = base_model.transformer.config
         self.bias = base_model.transformer.bias
         self.multi_query = base_model.transformer.multi_query