From 61bc5b14436ef17772ddd342f31de1994dc05a62 Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 18:00:28 +0000 Subject: [PATCH 01/13] init --- examples/architext.py | 14 +- examples/ilql_sentiments.py | 12 +- examples/nemo_ilql_inference.py | 18 ++- examples/nemo_ilql_sentiments.py | 32 ++++- examples/ppo_sentiments.py | 14 +- examples/randomwalks/ilql_randomwalks.py | 11 +- examples/randomwalks/ppo_randomwalks.py | 11 +- examples/simulacra.py | 2 + trlx/data/default_configs.py | 175 +++++++++++++++++++++++ 9 files changed, 230 insertions(+), 59 deletions(-) create mode 100644 trlx/data/default_configs.py diff --git a/examples/architext.py b/examples/architext.py index d854c4858..141b7b3c7 100644 --- a/examples/architext.py +++ b/examples/architext.py @@ -1,12 +1,7 @@ # Toy example of optimizing textual interior designs to output the least number of rooms # Also see https://architext.design/ -import pathlib - -import yaml - import trlx -from trlx.data.configs import TRLConfig - +from trlx.data.default_configs import default_ppo_config def reward_fn(samples, **kwargs): "Gives a negative count of rooms for each sample" @@ -30,13 +25,8 @@ def reward_fn(samples, **kwargs): "[prompt] the kitchen is not adjacent to the bathroom [layout]", ] -config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - - def main(hparams={}): - config = TRLConfig.update(default_config, hparams) + config = default_ppo_config() trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config) diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index 03caa66aa..c1217bfca 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -1,27 +1,19 @@ import os -import pathlib from typing import Dict, List -import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.configs import TRLConfig - +from trlx.data.default_configs import default_ilql_config def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -config_path = pathlib.Path(__file__).parent.joinpath("../configs/ilql_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - - def main(hparams={}): - config = TRLConfig.update(default_config, hparams) + config = default_ilql_config() sentiment_fn = pipeline( "sentiment-analysis", diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index 425a8cdb2..1fb96d9e6 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -2,7 +2,6 @@ import sys from glob import glob -import yaml from nemo.collections.nlp.modules.common.megatron.megatron_init import ( fake_initialize_model_parallel, ) @@ -10,11 +9,23 @@ from nemo.utils.model_utils import inject_model_parallel_rank from omegaconf.omegaconf import OmegaConf -from trlx.data.configs import TRLConfig +from trlx.data.configs import TRLConfig, TrainConfig, SchedulerConfig +from trlx.data.default_configs import default_ilql_config + from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer -default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) +default_config = default_ilql_config() + +nemo_ilql_train_cfg = TrainConfig( + **default_config.train.__dict__, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ) +) +trl_config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg) def find_checkpoints(checkpoint_dir): checkpoints = glob(os.path.join(checkpoint_dir, "*", "*.ckpt")) @@ -23,7 +34,6 @@ def find_checkpoints(checkpoint_dir): def main(megatron_cfg_path, checkpoint_path): - trl_config = TRLConfig.update(default_config, {}) ilql_config = trl_config.method megatron_cfg = OmegaConf.load(megatron_cfg_path) diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 82abe3b7b..af399bb0d 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -1,12 +1,11 @@ -import os from typing import Dict, List -import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.configs import TRLConfig +from trlx.data.configs import TRLConfig, TrainConfig, SchedulerConfig +from trlx.data.default_configs import default_ilql_config def get_positive_score(scores): @@ -14,11 +13,32 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml")) +default_config = default_ilql_config() + +nemo_ilql_train_cfg = TrainConfig( + **default_config.train.__dict__, + seq_length=1024, + batch_size=512, + total_steps=200, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ) +) + +scheduler_cfg = SchedulerConfig( + name="cosine_annealing", + kwargs=dict( + T_max=nemo_ilql_train_cfg.total_steps, + eta_min=1.0e-6 + ) +) + +config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg, scheduler=scheduler_cfg) -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) +def main(): sentiment_fn = pipeline( "sentiment-analysis", diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 2f8edc91a..cbb96d2bc 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -2,30 +2,22 @@ # with a sentiment reward function import os -import pathlib from typing import List import torch -import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.configs import TRLConfig - +from trlx.data.default_configs import default_ppo_config def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - - -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) +def main(): + config = default_ppo_config() if torch.cuda.is_available(): device = int(os.environ.get("LOCAL_RANK", 0)) diff --git a/examples/randomwalks/ilql_randomwalks.py b/examples/randomwalks/ilql_randomwalks.py index ebc31660a..48cdee9b7 100644 --- a/examples/randomwalks/ilql_randomwalks.py +++ b/examples/randomwalks/ilql_randomwalks.py @@ -5,15 +5,10 @@ import trlx from examples.randomwalks import generate_random_walks -from trlx.data.configs import TRLConfig +from trlx.data.default_configs import default_ilql_config -config_path = pathlib.Path(__file__).parent.joinpath("configs/ilql_randomwalks.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - - -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) +def main(): + config = default_ilql_config() metric_fn, eval_prompts, walks, _ = generate_random_walks(seed=config.train.seed) rewards = metric_fn(walks)["optimality"] diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index 113897fe6..7878230de 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -5,15 +5,10 @@ import trlx from examples.randomwalks import generate_random_walks from trlx.data.configs import TRLConfig - -config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_randomwalks.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) - +from trlx.data.default_configs import default_ppo_config def main(hparams={}): - config = TRLConfig.update(default_config, hparams) - + config = default_ppo_config() metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) trlx.train( @@ -25,7 +20,7 @@ def main(hparams={}): prompts=prompts, eval_prompts=prompts, metric_fn=lambda samples, prompts, outputs: metric_fn(samples), - config=config, + config=config ) diff --git a/examples/simulacra.py b/examples/simulacra.py index cc28520d6..f4d6f82d8 100644 --- a/examples/simulacra.py +++ b/examples/simulacra.py @@ -6,6 +6,7 @@ from urllib.request import urlretrieve import trlx +from trlx.data.default_configs import default_ilql_config url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite" dbpath = "sac_public_2022_06_29.sqlite" @@ -26,6 +27,7 @@ prompts, ratings = tuple(map(list, zip(*c.fetchall()))) trlx.train( + config=default_ilql_config(), samples=prompts, rewards=ratings, eval_prompts=["Hatsune Miku, Red Dress"] * 64, diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py new file mode 100644 index 000000000..81306194e --- /dev/null +++ b/trlx/data/default_configs.py @@ -0,0 +1,175 @@ +from .configs import TRLConfig, TrainConfig, ModelConfig, TokenizerConfig, OptimizerConfig, SchedulerConfig +from ..trainer.nn.ppo_models import PPOConfig +from ..trainer.nn.ilql_models import ILQLConfig + +def default_ppo_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=10000, + batch_size=32, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer" + ), + model=ModelConfig( + model_path="lwerra/gpt2-imdb", + num_layers_unfrozen=2 + ), + tokenizer=TokenizerConfig( + tokenizer_path="gpt2", + truncation_side="right" + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs=dict( + lr=1.0e-4, + betas=(0.9, 0.95), + eps=1.0e-8, + weight_decay=1.0e-6 + ) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs=dict( + T_max=10000, + eta_min=1.0e-4 + ) + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=128, + chunk_size=128, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1, + scale_reward="ignored", + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs=dict( + max_new_tokens=40, + top_k=0, + top_p=1.0, + do_sample=True, + ) + ) + ) +''' +train: + seq_length: 64 + batch_size: 128 + epochs: 100 + total_steps: 1000 + + checkpoint_interval: 1000 + eval_interval: 100 + + pipeline: "PromptPipeline" + trainer: "AccelerateILQLTrainer" + seed: 1000 + +model: + model_path: "gpt2" + num_layers_unfrozen: -1 + +tokenizer: + tokenizer_path: "gpt2" + truncation_side: "right" + +optimizer: + name: "adamw" + kwargs: + lr: 5.0e-5 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 1.0e-6 + +scheduler: + name: "cosine_annealing" + kwargs: + T_max: 1000 # train.total_steps + eta_min: 5.0e-5 + +method: + name: "ilqlconfig" + tau: 0.7 + gamma: 0.99 + cql_scale: 0.1 + awac_scale: 1 + alpha: 0.001 + beta: 0 + steps_for_target_q_sync: 5 + two_qs: true + gen_kwargs: + max_new_tokens: 56 + top_k: 20 + beta: 4 + temperature: 1.0 +''' +# rewrite the above in python below + +def default_ilql_config(): + return TRLConfig( + train=TrainConfig( + seq_length=64, + batch_size=128, + epochs=100, + total_steps=1000, + checkpoint_interval=1000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateILQLTrainer" + ), + model=ModelConfig( + model_path="gpt2", + num_layers_unfrozen=-1 + ), + tokenizer=TokenizerConfig( + tokenizer_path="gpt2", + truncation_side="right" + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs=dict( + lr=5.0e-5, + betas=(0.9, 0.95), + eps=1.0e-8, + weight_decay=1.0e-6 + ) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs=dict( + T_max=1000, # train.total_steps + eta_min=5.0e-5 + ) + ), + method=ILQLConfig( + name="ilqlconfig", + tau=0.7, + gamma=0.99, + cql_scale=0.1, + awac_scale=1, + alpha=0.001, + beta=0, + steps_for_target_q_sync=5, + two_qs=True, + gen_kwargs=dict( + max_new_tokens=56, + top_k=20, + beta=4, + temperature=1.0 + ) + ) + ) + + From 7820d981d0f2d47787c61494b98bd6158b979040 Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 18:02:18 +0000 Subject: [PATCH 02/13] remove unused --- examples/architext.py | 2 + examples/ilql_sentiments.py | 1 + examples/nemo_ilql_inference.py | 6 +- examples/nemo_ilql_sentiments.py | 9 +-- examples/ppo_sentiments.py | 1 + examples/randomwalks/ilql_randomwalks.py | 4 +- examples/randomwalks/ppo_randomwalks.py | 8 +-- trlx/data/default_configs.py | 88 ++++++++---------------- 8 files changed, 41 insertions(+), 78 deletions(-) diff --git a/examples/architext.py b/examples/architext.py index 141b7b3c7..d3461d4b9 100644 --- a/examples/architext.py +++ b/examples/architext.py @@ -3,6 +3,7 @@ import trlx from trlx.data.default_configs import default_ppo_config + def reward_fn(samples, **kwargs): "Gives a negative count of rooms for each sample" return [-sample.count(":") for sample in samples] @@ -25,6 +26,7 @@ def reward_fn(samples, **kwargs): "[prompt] the kitchen is not adjacent to the bathroom [layout]", ] + def main(hparams={}): config = default_ppo_config() diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index c1217bfca..7b96b08f5 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -7,6 +7,7 @@ import trlx from trlx.data.default_configs import default_ilql_config + def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index 1fb96d9e6..6d90f0bf6 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -9,9 +9,8 @@ from nemo.utils.model_utils import inject_model_parallel_rank from omegaconf.omegaconf import OmegaConf -from trlx.data.configs import TRLConfig, TrainConfig, SchedulerConfig +from trlx.data.configs import TrainConfig, TRLConfig from trlx.data.default_configs import default_ilql_config - from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer default_config = default_ilql_config() @@ -22,11 +21,12 @@ trainer_kwargs=dict( pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", megatron_cfg="megatron_20b.yaml", - ) + ), ) trl_config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg) + def find_checkpoints(checkpoint_dir): checkpoints = glob(os.path.join(checkpoint_dir, "*", "*.ckpt")) names = [os.path.basename(c) for c in checkpoints] diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index af399bb0d..6067ff567 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -4,7 +4,7 @@ from transformers import pipeline import trlx -from trlx.data.configs import TRLConfig, TrainConfig, SchedulerConfig +from trlx.data.configs import SchedulerConfig, TrainConfig, TRLConfig from trlx.data.default_configs import default_ilql_config @@ -28,18 +28,13 @@ def get_positive_score(scores): ) scheduler_cfg = SchedulerConfig( - name="cosine_annealing", - kwargs=dict( - T_max=nemo_ilql_train_cfg.total_steps, - eta_min=1.0e-6 - ) + name="cosine_annealing", kwargs=dict(T_max=nemo_ilql_train_cfg.total_steps, eta_min=1.0e-6) ) config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg, scheduler=scheduler_cfg) def main(): - sentiment_fn = pipeline( "sentiment-analysis", "lvwerra/distilbert-imdb", diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index cbb96d2bc..8c33489c8 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -11,6 +11,7 @@ import trlx from trlx.data.default_configs import default_ppo_config + def get_positive_score(scores): "Extract value associated with a positive sentiment from pipeline's output" return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] diff --git a/examples/randomwalks/ilql_randomwalks.py b/examples/randomwalks/ilql_randomwalks.py index 48cdee9b7..043787e10 100644 --- a/examples/randomwalks/ilql_randomwalks.py +++ b/examples/randomwalks/ilql_randomwalks.py @@ -1,12 +1,10 @@ -import pathlib - -import yaml from transformers import GPT2Config import trlx from examples.randomwalks import generate_random_walks from trlx.data.default_configs import default_ilql_config + def main(): config = default_ilql_config() diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index 7878230de..ad1e7919e 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -1,12 +1,8 @@ -import pathlib - -import yaml - import trlx from examples.randomwalks import generate_random_walks -from trlx.data.configs import TRLConfig from trlx.data.default_configs import default_ppo_config + def main(hparams={}): config = default_ppo_config() metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) @@ -20,7 +16,7 @@ def main(hparams={}): prompts=prompts, eval_prompts=prompts, metric_fn=lambda samples, prompts, outputs: metric_fn(samples), - config=config + config=config, ) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 81306194e..78b80bea8 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -1,6 +1,14 @@ -from .configs import TRLConfig, TrainConfig, ModelConfig, TokenizerConfig, OptimizerConfig, SchedulerConfig -from ..trainer.nn.ppo_models import PPOConfig from ..trainer.nn.ilql_models import ILQLConfig +from ..trainer.nn.ppo_models import PPOConfig +from .configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + def default_ppo_config(): return TRLConfig( @@ -12,32 +20,14 @@ def default_ppo_config(): checkpoint_interval=10000, eval_interval=100, pipeline="PromptPipeline", - trainer="AcceleratePPOTrainer" - ), - model=ModelConfig( - model_path="lwerra/gpt2-imdb", - num_layers_unfrozen=2 - ), - tokenizer=TokenizerConfig( - tokenizer_path="gpt2", - truncation_side="right" + trainer="AcceleratePPOTrainer", ), + model=ModelConfig(model_path="lwerra/gpt2-imdb", num_layers_unfrozen=2), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), optimizer=OptimizerConfig( - name="adamw", - kwargs=dict( - lr=1.0e-4, - betas=(0.9, 0.95), - eps=1.0e-8, - weight_decay=1.0e-6 - ) - ), - scheduler=SchedulerConfig( - name="cosine_annealing", - kwargs=dict( - T_max=10000, - eta_min=1.0e-4 - ) + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) ), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), method=PPOConfig( name="PPOConfig", num_rollouts=128, @@ -60,10 +50,12 @@ def default_ppo_config(): top_k=0, top_p=1.0, do_sample=True, - ) - ) + ), + ), ) -''' + + +""" train: seq_length: 64 batch_size: 128 @@ -114,9 +106,10 @@ def default_ppo_config(): top_k: 20 beta: 4 temperature: 1.0 -''' +""" # rewrite the above in python below + def default_ilql_config(): return TRLConfig( train=TrainConfig( @@ -127,31 +120,15 @@ def default_ilql_config(): checkpoint_interval=1000, eval_interval=100, pipeline="PromptPipeline", - trainer="AccelerateILQLTrainer" - ), - model=ModelConfig( - model_path="gpt2", - num_layers_unfrozen=-1 - ), - tokenizer=TokenizerConfig( - tokenizer_path="gpt2", - truncation_side="right" + trainer="AccelerateILQLTrainer", ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), optimizer=OptimizerConfig( - name="adamw", - kwargs=dict( - lr=5.0e-5, - betas=(0.9, 0.95), - eps=1.0e-8, - weight_decay=1.0e-6 - ) + name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) ), scheduler=SchedulerConfig( - name="cosine_annealing", - kwargs=dict( - T_max=1000, # train.total_steps - eta_min=5.0e-5 - ) + name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5) # train.total_steps ), method=ILQLConfig( name="ilqlconfig", @@ -163,13 +140,6 @@ def default_ilql_config(): beta=0, steps_for_target_q_sync=5, two_qs=True, - gen_kwargs=dict( - max_new_tokens=56, - top_k=20, - beta=4, - temperature=1.0 - ) - ) + gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0), + ), ) - - From f3c40e52bd42a53b6ba54a41945f7742ab6e9f50 Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 18:32:40 +0000 Subject: [PATCH 03/13] add to train --- examples/randomwalks/ppo_randomwalks.py | 2 +- examples/sft_sentiments.py | 12 +--- trlx/data/default_configs.py | 83 +++++++++---------------- trlx/trlx.py | 11 +++- 4 files changed, 40 insertions(+), 68 deletions(-) diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index ad1e7919e..1675b4865 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -3,7 +3,7 @@ from trlx.data.default_configs import default_ppo_config -def main(hparams={}): +def main(): config = default_ppo_config() metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) diff --git a/examples/sft_sentiments.py b/examples/sft_sentiments.py index c289d3a38..9f22077f4 100644 --- a/examples/sft_sentiments.py +++ b/examples/sft_sentiments.py @@ -1,17 +1,11 @@ import os -import pathlib from typing import Dict, List -import yaml from datasets import load_dataset from transformers import pipeline import trlx -from trlx.data.configs import TRLConfig - -config_path = pathlib.Path(__file__).parent.joinpath("../configs/sft_config.yml") -with config_path.open() as f: - default_config = yaml.safe_load(f) +from trlx.data.default_configs import default_sft_config def get_positive_score(scores): @@ -19,8 +13,8 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) +def main(): + config = default_sft_config() imdb = load_dataset("imdb", split="train+test") # Finetune on only positive reviews diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 78b80bea8..607585383 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -1,3 +1,4 @@ +from ..trainer.accelerate_sft_trainer import SFTConfig from ..trainer.nn.ilql_models import ILQLConfig from ..trainer.nn.ppo_models import PPOConfig from .configs import ( @@ -55,61 +56,6 @@ def default_ppo_config(): ) -""" -train: - seq_length: 64 - batch_size: 128 - epochs: 100 - total_steps: 1000 - - checkpoint_interval: 1000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AccelerateILQLTrainer" - seed: 1000 - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 1000 # train.total_steps - eta_min: 5.0e-5 - -method: - name: "ilqlconfig" - tau: 0.7 - gamma: 0.99 - cql_scale: 0.1 - awac_scale: 1 - alpha: 0.001 - beta: 0 - steps_for_target_q_sync: 5 - two_qs: true - gen_kwargs: - max_new_tokens: 56 - top_k: 20 - beta: 4 - temperature: 1.0 -""" -# rewrite the above in python below - - def default_ilql_config(): return TRLConfig( train=TrainConfig( @@ -143,3 +89,30 @@ def default_ilql_config(): gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0), ), ) + + +def default_sft_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=8, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateSFTTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4) # train.total_steps + ), + method=SFTConfig( + name="sftconfig", + gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), + ), + ) diff --git a/trlx/trlx.py b/trlx/trlx.py index 10e3621d1..fbfff75bf 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -3,6 +3,11 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple from trlx.data.configs import TRLConfig +from trlx.data.default_configs import ( + default_ilql_config, + default_ppo_config, + default_sft_config, +) from trlx.utils import set_seed from trlx.utils.loading import get_pipeline, get_trainer @@ -55,11 +60,11 @@ def train( # noqa: C901 "Passing the `config` argument implicitly is depreciated, load it from `configs` directory instead" ) if reward_fn: - config = TRLConfig.load_yaml("configs/ppo_config.yml") + config = default_ppo_config() elif rewards: - config = TRLConfig.load_yaml("configs/ilql_config.yml") + config = default_ilql_config() else: - config = TRLConfig.load_yaml("configs/sft_config.yml") + config = default_sft_config() set_seed(config.train.seed) From b208d8babe82cfe84e95d6a4299324699f37546f Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 22:13:44 +0000 Subject: [PATCH 04/13] remove default configs and update readme --- README.md | 13 ++++++++ configs/ilql_config.yml | 50 ---------------------------- configs/nemo_ilql_config.yml | 52 ----------------------------- configs/ppo_config.yml | 56 -------------------------------- configs/ppo_gptj.yml | 56 -------------------------------- configs/sft_config.yml | 41 ----------------------- examples/nemo_ilql_sentiments.py | 2 ++ trlx/data/default_configs.py | 4 +-- 8 files changed, 17 insertions(+), 257 deletions(-) delete mode 100644 configs/ilql_config.yml delete mode 100644 configs/nemo_ilql_config.yml delete mode 100644 configs/ppo_config.yml delete mode 100644 configs/ppo_gptj.yml delete mode 100644 configs/sft_config.yml diff --git a/README.md b/README.md index da9ba405d..93a39e46a 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,19 @@ trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0 trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True) ``` +#### Configure Hyperparameters + +```python +from trlx.data.default_configs import default_ppo_config, TrainConfig + +config = default_ppo_config() +config.model.model_path = 'EleutherAI/gpt-neox-20b' +config.train.seq_length = 32 +config.train.batch_size = 16 + +trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [float(int(sample)) for sample in samples]) +``` + #### Save the resulting model to a Hugging Face pretrained language model. (Ready to upload to the Hub!) ```python diff --git a/configs/ilql_config.yml b/configs/ilql_config.yml deleted file mode 100644 index 40c162c70..000000000 --- a/configs/ilql_config.yml +++ /dev/null @@ -1,50 +0,0 @@ -train: - seq_length: 64 - batch_size: 128 - epochs: 100 - total_steps: 1000 - - checkpoint_interval: 1000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AccelerateILQLTrainer" - seed: 1000 - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 1000 # train.total_steps - eta_min: 5.0e-5 - -method: - name: "ilqlconfig" - tau: 0.7 - gamma: 0.99 - cql_scale: 0.1 - awac_scale: 1 - alpha: 0.001 - beta: 0 - steps_for_target_q_sync: 5 - two_qs: true - gen_kwargs: - max_new_tokens: 56 - top_k: 20 - beta: 4 - temperature: 1.0 diff --git a/configs/nemo_ilql_config.yml b/configs/nemo_ilql_config.yml deleted file mode 100644 index 1d4cc71e2..000000000 --- a/configs/nemo_ilql_config.yml +++ /dev/null @@ -1,52 +0,0 @@ -train: - seq_length: 1024 - batch_size: 512 - epochs: 100 - total_steps: 200 - checkpoint_interval: 200 - eval_interval: 20 - - pipeline: "PromptPipeline" - trainer: "NeMoILQLTrainer" - trainer_kwargs: - pretrained_model: "/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/" - megatron_cfg: "megatron_20b.yaml" - seed: 1000 - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-5 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 2000 # train.total_steps - eta_min: 1.0e-6 - -method: - name: "ilqlconfig" - tau: 0.7 - gamma: 0.99 - cql_scale: 0.1 - awac_scale: 1 - alpha: 0.001 - beta: 0 - steps_for_target_q_sync: 5 - two_qs: True - gen_kwargs: - max_new_tokens: 56 - top_k: 20 - beta: 2 - temperature: 0.9 diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml deleted file mode 100644 index 92388a2be..000000000 --- a/configs/ppo_config.yml +++ /dev/null @@ -1,56 +0,0 @@ -train: - seq_length: 1024 - epochs: 100 - total_steps: 10000 - batch_size: 128 - - checkpoint_interval: 10000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "lvwerra/gpt2-imdb" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 10000 # train.total_steps - eta_min: 1.0e-4 - -method: - name: "ppoconfig" - num_rollouts: 128 - chunk_size: 128 - ppo_epochs: 4 - init_kl_coef: 0.05 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 1 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 40 - top_k: 0 - top_p: 1.0 - do_sample: True diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml deleted file mode 100644 index 0595f7ded..000000000 --- a/configs/ppo_gptj.yml +++ /dev/null @@ -1,56 +0,0 @@ -train: - seq_length: 48 - epochs: 10 - total_steps: 80000 - batch_size: 8 - - checkpoint_interval: 1000000 - eval_interval: 16 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "EleutherAI/gpt-j-6B" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "gpt2" - -optimizer: - name: "adamw" - kwargs: - lr: 1.412e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 80000 # train.total_steps - eta_min: 1.412e-4 - -method: - name: "ppoconfig" - num_rollouts: 8 - chunk_size: 8 - ppo_epochs: 4 - init_kl_coef: 0.2 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 0.2 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 48 - top_k: 0.0 - top_p: 0.7 - do_sample: True - temperature: 0.5 diff --git a/configs/sft_config.yml b/configs/sft_config.yml deleted file mode 100644 index 4b1efe358..000000000 --- a/configs/sft_config.yml +++ /dev/null @@ -1,41 +0,0 @@ -train: - seq_length: 1024 - epochs: 100 - total_steps: 1000 - batch_size: 8 - - checkpoint_interval: 10000 - eval_interval: 100 - - pipeline: "PromptPipeline" - trainer: "AccelerateSFTTrainer" - -model: - model_path: "gpt2" - num_layers_unfrozen: -1 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-4 - betas: [0.9, 0.95] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 10000 # train.total_steps - eta_min: 1.0e-4 - -method: - name: "sftconfig" - gen_kwargs: - max_new_tokens: 40 - top_k: 0 - top_p: 1.0 - do_sample: True diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 6067ff567..4a790fa94 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -35,6 +35,8 @@ def get_positive_score(scores): def main(): + print(config) + return sentiment_fn = pipeline( "sentiment-analysis", "lvwerra/distilbert-imdb", diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 607585383..3f1196b11 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -23,7 +23,7 @@ def default_ppo_config(): pipeline="PromptPipeline", trainer="AcceleratePPOTrainer", ), - model=ModelConfig(model_path="lwerra/gpt2-imdb", num_layers_unfrozen=2), + model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), optimizer=OptimizerConfig( name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) @@ -60,7 +60,7 @@ def default_ilql_config(): return TRLConfig( train=TrainConfig( seq_length=64, - batch_size=128, + batch_size=32, epochs=100, total_steps=1000, checkpoint_interval=1000, From 88c5b97e15b20ead9dcb3f05874c12d4ee7fc558 Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 22:17:57 +0000 Subject: [PATCH 05/13] remove dbg --- examples/nemo_ilql_sentiments.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 4a790fa94..6067ff567 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -35,8 +35,6 @@ def get_positive_score(scores): def main(): - print(config) - return sentiment_fn = pipeline( "sentiment-analysis", "lvwerra/distilbert-imdb", From 01b24c4ceae68a65818a6bb41bd65e67dcb5924e Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 23:16:30 +0000 Subject: [PATCH 06/13] remove superfluous arg --- examples/ilql_sentiments.py | 2 +- examples/randomwalks/ppo_randomwalks.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/ilql_sentiments.py b/examples/ilql_sentiments.py index 7b96b08f5..ad75e5113 100644 --- a/examples/ilql_sentiments.py +++ b/examples/ilql_sentiments.py @@ -13,7 +13,7 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -def main(hparams={}): +def main(): config = default_ilql_config() sentiment_fn = pipeline( diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index 1675b4865..3bcad20e6 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -5,6 +5,8 @@ def main(): config = default_ppo_config() + config.model.model_path = "gpt2" + metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) trlx.train( From 2df3b6291967845ebe3fee40bc2662c59a507dd8 Mon Sep 17 00:00:00 2001 From: cat-state Date: Mon, 13 Feb 2023 23:17:15 +0000 Subject: [PATCH 07/13] remove superfluous arg --- examples/architext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/architext.py b/examples/architext.py index d3461d4b9..6e31f3497 100644 --- a/examples/architext.py +++ b/examples/architext.py @@ -27,7 +27,7 @@ def reward_fn(samples, **kwargs): ] -def main(hparams={}): +def main(): config = default_ppo_config() trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config) From 49226d8f9a4a9e29f98aee6b1f9296951814f175 Mon Sep 17 00:00:00 2001 From: cat-state Date: Tue, 14 Feb 2023 00:35:22 +0000 Subject: [PATCH 08/13] t5 examples --- .../configs/ppo_config_cnn_daily.yml | 61 -------------- .../t5_summarize_daily_cnn.py | 79 ++++++++++++++++++- .../configs/ppo_config_summ_gptj.yml | 53 ------------- .../trlx_gptj_text_summarization.py | 72 +++++++++++++++-- trlx/trainer/nn/ppo_models.py | 2 +- 5 files changed, 143 insertions(+), 124 deletions(-) delete mode 100755 examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml delete mode 100755 examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml diff --git a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml b/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml deleted file mode 100755 index 2134beadd..000000000 --- a/examples/summarize_daily_cnn/configs/ppo_config_cnn_daily.yml +++ /dev/null @@ -1,61 +0,0 @@ -train: - seq_length: 612 - epochs: 100 - total_steps: 100000 - batch_size: 12 - - checkpoint_interval: 10000 - eval_interval: 500 - save_best: False - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "google/flan-t5-large" - model_arch_type: "seq2seq" - num_layers_unfrozen: 2 - -tokenizer: - tokenizer_path: "google/flan-t5-large" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 1.0e-5 - betas: [0.9, 0.999] - eps: 1.0e-8 - weight_decay: 1.0e-6 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 10000 - eta_min: 1.0e-6 - -method: - name: "ppoconfig" - num_rollouts: 512 - chunk_size: 12 - ppo_epochs: 4 - init_kl_coef: 0.05 - target: 6 - horizon: 10000 - gamma: 0.99 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 1.0 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 100 - gen_experience_kwargs: - max_new_tokens: 100 - do_sample: True - temperature: 1.0 - top_k: 50 - top_p: 0.95 diff --git a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py index 67863bf7d..318d38b1c 100755 --- a/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py +++ b/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py @@ -1,4 +1,3 @@ -import pathlib from typing import List from datasets import load_dataset @@ -6,7 +5,15 @@ from transformers import AutoTokenizer import trlx -from trlx.data.configs import TRLConfig +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.trainer.nn.ppo_models import PPOConfig try: import evaluate @@ -15,8 +22,72 @@ "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" ) -config_path = pathlib.Path(__file__).parent / "configs/ppo_config_cnn_daily.yml" -config = TRLConfig.load_yaml(config_path) +config = TRLConfig( + train=TrainConfig( + seq_length=612, + epochs=100, + total_steps=100000, + batch_size=12, + checkpoint_interval=10000, + eval_interval=500, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + ), + model=ModelConfig( + model_path="google/flan-t5-large", + model_arch_type="seq2seq", + num_layers_unfrozen=2, + ), + tokenizer=TokenizerConfig( + tokenizer_path="google/flan-t5-large", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 1.0e-5, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 1.0e-6, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 10000, + "eta_min": 1.0e-6, + }, + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=512, + chunk_size=12, + ppo_epochs=4, + init_kl_coef=0.05, + target=6, + horizon=10000, + gamma=0.99, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=1.0, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ + "max_new_tokens": 100, + }, + gen_experience_kwargs={ + "max_new_tokens": 100, + "do_sample": True, + "temperature": 1.0, + "top_k": 50, + "top_p": 0.95, + }, + ), +) + meteor = evaluate.load("meteor") # use meteor as the reward function diff --git a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml b/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml deleted file mode 100755 index 8055a49b5..000000000 --- a/examples/summarize_rlhf/configs/ppo_config_summ_gptj.yml +++ /dev/null @@ -1,53 +0,0 @@ -train: - seq_length: 550 - epochs: 50 - total_steps: 100000 - batch_size: 4 - - checkpoint_interval: 10000 - eval_interval: 200 - - pipeline: "PromptPipeline" - trainer: "AcceleratePPOTrainer" - -model: - model_path: "CarperAI/openai_summarize_tldr_sft" - num_layers_unfrozen: 8 - -tokenizer: - tokenizer_path: "gpt2" - truncation_side: "right" - -optimizer: - name: "adamw" - kwargs: - lr: 5.0e-6 - betas: [0.9, 0.999] - eps: 1.0e-8 - weight_decay: 0.01 - -scheduler: - name: "cosine_annealing" - kwargs: - T_max: 100000 - eta_min: 5.0e-6 - -method: - name: "ppoconfig" - num_rollouts: 128 - chunk_size: 16 - ppo_epochs: 4 - init_kl_coef: 0.1 - target: 6 - horizon: 10000 - gamma: 1 - lam: 0.95 - cliprange: 0.2 - cliprange_value: 0.2 - vf_coef: 0.2 - scale_reward: False - ref_mean: null - ref_std: null - cliprange_reward: 10 - gen_kwargs: - max_new_tokens: 50 diff --git a/examples/summarize_rlhf/trlx_gptj_text_summarization.py b/examples/summarize_rlhf/trlx_gptj_text_summarization.py index 3d9e3c5f3..9d0d8dd46 100755 --- a/examples/summarize_rlhf/trlx_gptj_text_summarization.py +++ b/examples/summarize_rlhf/trlx_gptj_text_summarization.py @@ -1,5 +1,4 @@ import os -import pathlib from typing import List import torch @@ -9,7 +8,15 @@ from transformers import AutoTokenizer import trlx -from trlx.data.configs import TRLConfig +from trlx.data.configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.trainer.nn.ppo_models import PPOConfig REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" if not os.path.exists(REWARD_CHECKPOINT_PATH): @@ -20,6 +27,64 @@ ) SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" +config = TRLConfig( + train=TrainConfig( + seq_length=550, + epochs=50, + total_steps=100000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=200, + pipeline="PromptPipeline", + trainer="AcceleratePPOTrainer", + ), + model=ModelConfig( + model_path="CarperAI/openai_summarize_tldr_sft", + num_layers_unfrozen=8, + ), + tokenizer=TokenizerConfig( + tokenizer_path="gpt2", + truncation_side="right", + ), + optimizer=OptimizerConfig( + name="adamw", + kwargs={ + "lr": 5.0e-6, + "betas": [0.9, 0.999], + "eps": 1.0e-8, + "weight_decay": 0.01, + }, + ), + scheduler=SchedulerConfig( + name="cosine_annealing", + kwargs={ + "T_max": 100000, + "eta_min": 5.0e-6, + }, + ), + method=PPOConfig( + name="PPOConfig", + num_rollouts=128, + chunk_size=16, + ppo_epochs=4, + init_kl_coef=0.1, + target=6, + horizon=10000, + gamma=1, + lam=0.95, + cliprange=0.2, + cliprange_value=0.2, + vf_coef=0.2, + scale_reward=None, + ref_mean=None, + ref_std=None, + cliprange_reward=10, + gen_kwargs={ + "max_new_tokens": 50, + }, + ), +) + if __name__ == "__main__": # Load the pre-trained reward model @@ -87,9 +152,6 @@ def reward_fn(samples: List[str], **kwargs): norms_scores = scores - original_scores return norms_scores - config_path = pathlib.Path(__file__).parent.joinpath("configs/ppo_config_summ_gptj.yml") - config = TRLConfig.load_yaml(config_path) - tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" diff --git a/trlx/trainer/nn/ppo_models.py b/trlx/trainer/nn/ppo_models.py index 6cbb64d25..5d304a14c 100644 --- a/trlx/trainer/nn/ppo_models.py +++ b/trlx/trainer/nn/ppo_models.py @@ -122,7 +122,7 @@ class PPOConfig(MethodConfig): cliprange: float cliprange_value: float vf_coef: float - scale_reward: str + scale_reward: Optional[str] ref_mean: Optional[float] ref_std: Optional[float] cliprange_reward: float From 5f18fc54823b70c5014841b24ea8e3bca9f2dc3b Mon Sep 17 00:00:00 2001 From: cat-state Date: Wed, 15 Feb 2023 22:37:41 +0000 Subject: [PATCH 09/13] fmt --- examples/nemo_ilql_inference.py | 14 ++++++++------ examples/nemo_ilql_sentiments.py | 26 ++++++++++++-------------- setup.cfg | 2 ++ 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index 6d90f0bf6..28217daa0 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -16,15 +16,17 @@ default_config = default_ilql_config() nemo_ilql_train_cfg = TrainConfig( - **default_config.train.__dict__, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", + **dict( + default_config.train.__dict__, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), ), ) -trl_config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg) +trl_config = TRLConfig(**dict(default_config.__dict__, train=nemo_ilql_train_cfg)) def find_checkpoints(checkpoint_dir): diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index 6067ff567..bbf5b62e9 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -4,7 +4,7 @@ from transformers import pipeline import trlx -from trlx.data.configs import SchedulerConfig, TrainConfig, TRLConfig +from trlx.data.configs import TrainConfig, TRLConfig from trlx.data.default_configs import default_ilql_config @@ -16,22 +16,20 @@ def get_positive_score(scores): default_config = default_ilql_config() nemo_ilql_train_cfg = TrainConfig( - **default_config.train.__dict__, - seq_length=1024, - batch_size=512, - total_steps=200, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", + **dict( + default_config.train.__dict__, + seq_length=1024, + batch_size=512, + total_steps=200, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), ) ) -scheduler_cfg = SchedulerConfig( - name="cosine_annealing", kwargs=dict(T_max=nemo_ilql_train_cfg.total_steps, eta_min=1.0e-6) -) - -config = TRLConfig(**default_config.__dict__, train=nemo_ilql_train_cfg, scheduler=scheduler_cfg) +config = TRLConfig(**dict(default_config.__dict__, train=nemo_ilql_train_cfg)) def main(): diff --git a/setup.cfg b/setup.cfg index 4a54f7747..2ff7a3d43 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,6 +12,8 @@ license = MIT packages = find: install_requires = accelerate>=0.12.0 + attrs>=22.1.0 + cattrs>=22.2.0 datasets deepspeed>=0.7.3 einops>=0.4.1 From 88272ab36f718abff3e8f16c7e742ee3b29d7044 Mon Sep 17 00:00:00 2001 From: cat-state Date: Thu, 23 Feb 2023 01:43:16 +0000 Subject: [PATCH 10/13] add evolve --- examples/nemo_ilql_inference.py | 22 +++++++++--------- examples/nemo_ilql_sentiments.py | 31 +++++++++++-------------- examples/ppo_sentiments.py | 4 ++-- examples/randomwalks/ppo_randomwalks.py | 3 +-- examples/sft_sentiments.py | 4 ++-- trlx/data/configs.py | 10 ++++++++ 6 files changed, 40 insertions(+), 34 deletions(-) diff --git a/examples/nemo_ilql_inference.py b/examples/nemo_ilql_inference.py index 28217daa0..f172f6fbb 100644 --- a/examples/nemo_ilql_inference.py +++ b/examples/nemo_ilql_inference.py @@ -9,25 +9,25 @@ from nemo.utils.model_utils import inject_model_parallel_rank from omegaconf.omegaconf import OmegaConf -from trlx.data.configs import TrainConfig, TRLConfig +from trlx.data.configs import TrainConfig from trlx.data.default_configs import default_ilql_config from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer default_config = default_ilql_config() -nemo_ilql_train_cfg = TrainConfig( - **dict( - default_config.train.__dict__, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", +trl_config = default_config.evolve( + train=TrainConfig( + **dict( + default_config.train.__dict__, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), ), - ), + ) ) -trl_config = TRLConfig(**dict(default_config.__dict__, train=nemo_ilql_train_cfg)) - def find_checkpoints(checkpoint_dir): checkpoints = glob(os.path.join(checkpoint_dir, "*", "*.ckpt")) diff --git a/examples/nemo_ilql_sentiments.py b/examples/nemo_ilql_sentiments.py index bbf5b62e9..d7d4f23eb 100644 --- a/examples/nemo_ilql_sentiments.py +++ b/examples/nemo_ilql_sentiments.py @@ -4,7 +4,6 @@ from transformers import pipeline import trlx -from trlx.data.configs import TrainConfig, TRLConfig from trlx.data.default_configs import default_ilql_config @@ -15,24 +14,22 @@ def get_positive_score(scores): default_config = default_ilql_config() -nemo_ilql_train_cfg = TrainConfig( - **dict( - default_config.train.__dict__, - seq_length=1024, - batch_size=512, - total_steps=200, - trainer="NeMoILQLTrainer", - trainer_kwargs=dict( - pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", - megatron_cfg="megatron_20b.yaml", - ), - ) -) - -config = TRLConfig(**dict(default_config.__dict__, train=nemo_ilql_train_cfg)) +def main(hparams={}): + config = default_config.evolve( + train=dict( + seq_length=1024, + batch_size=512, + total_steps=200, + trainer="NeMoILQLTrainer", + trainer_kwargs=dict( + pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/", + megatron_cfg="megatron_20b.yaml", + ), + ) + ) + config = config.evolve(**hparams) -def main(): sentiment_fn = pipeline( "sentiment-analysis", "lvwerra/distilbert-imdb", diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 8c33489c8..ab8c2c75a 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -17,8 +17,8 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -def main(): - config = default_ppo_config() +def main(hparams={}): + config = default_ppo_config().evolve(**hparams) if torch.cuda.is_available(): device = int(os.environ.get("LOCAL_RANK", 0)) diff --git a/examples/randomwalks/ppo_randomwalks.py b/examples/randomwalks/ppo_randomwalks.py index 3bcad20e6..50981cbe2 100644 --- a/examples/randomwalks/ppo_randomwalks.py +++ b/examples/randomwalks/ppo_randomwalks.py @@ -4,8 +4,7 @@ def main(): - config = default_ppo_config() - config.model.model_path = "gpt2" + config = default_ppo_config().evolve(model=dict(model_path="gpt2")) metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) diff --git a/examples/sft_sentiments.py b/examples/sft_sentiments.py index 9f22077f4..245b4d13b 100644 --- a/examples/sft_sentiments.py +++ b/examples/sft_sentiments.py @@ -13,8 +13,8 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -def main(): - config = default_sft_config() +def main(hparams={}): + config = default_sft_config().evolve(**hparams) imdb = load_dataset("imdb", split="train+test") # Finetune on only positive reviews diff --git a/trlx/data/configs.py b/trlx/data/configs.py index a0a6feaec..193f99dfb 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -251,6 +251,16 @@ def to_dict(self): return data + def evolve(self, **kwargs) -> "TRLConfig": + """ + Evolve TRLConfig with new parameters. Can update nested parameters. + >>> config = trlx.data.default_configs.default_ilql_config() + >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)) + >>> config.method.gamma + 0.99 + """ + return TRLConfig.update(self.to_dict(), kwargs) + @classmethod def from_dict(cls, config: Dict): """ From 3e4731d0e446e5845b3b3d640867e9990082c60b Mon Sep 17 00:00:00 2001 From: cat-state Date: Thu, 23 Feb 2023 02:17:52 +0000 Subject: [PATCH 11/13] update merging --- trlx/data/configs.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 193f99dfb..4ebba925f 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,20 +1,22 @@ +from copy import deepcopy from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set +from typing import Any, Dict, Optional import yaml from trlx.data.method_configs import MethodConfig, get_method -def merge(base: Dict, update: Dict, updated: Set) -> Dict: +def merge(base: Dict, update: Dict) -> Dict: "Recursively updates a nested dictionary with new values" - for k, v in base.items(): - if k in update and isinstance(v, dict): - base[k] = merge(v, update[k], updated) - updated.add(k) - elif k in update: - base[k] = update[k] - updated.add(k) + + base = deepcopy(base) + + for k, v in update.items(): + if k in base and isinstance(v, dict) and isinstance(base[k], dict): + base[k] = merge(base[k], v) + else: + base[k] = v return base @@ -277,13 +279,12 @@ def from_dict(cls, config: Dict): @classmethod def update(cls, baseconfig: Dict, config: Dict): - updates = set() - merged = merge(baseconfig, config, updates) - for param in config: - if param not in updates: + if param not in baseconfig: raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)") + merged = merge(baseconfig, config) + return cls.from_dict(merged) def __str__(self): From 8c80056da266a0e198f47b5a91d7febe9b7e8648 Mon Sep 17 00:00:00 2001 From: cat-state Date: Thu, 23 Feb 2023 16:47:46 +0000 Subject: [PATCH 12/13] attrs --- trlx/data/configs.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 4ebba925f..74da82902 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,8 +1,8 @@ from copy import deepcopy -from dataclasses import dataclass, field from typing import Any, Dict, Optional import yaml +from attrs import define, field from trlx.data.method_configs import MethodConfig, get_method @@ -21,7 +21,7 @@ def merge(base: Dict, update: Dict) -> Dict: return base -@dataclass +@define class ModelConfig: """ Config for a model. @@ -61,7 +61,7 @@ def from_dict(cls, config: Dict[str, Any]): return cls(**config) -@dataclass +@define class TokenizerConfig: """ Config for a model. @@ -85,7 +85,7 @@ def from_dict(cls, config: Dict[str, Any]): return cls(**config) -@dataclass +@define class OptimizerConfig: """ Config for an optimizer. @@ -98,14 +98,14 @@ class OptimizerConfig: """ name: str - kwargs: Dict[str, Any] = field(default_factory=dict) + kwargs: Dict[str, Any] = field(factory=dict) @classmethod def from_dict(cls, config: Dict[str, Any]): return cls(**config) -@dataclass +@define class SchedulerConfig: """ Config for a learning rate scheduler. @@ -118,14 +118,14 @@ class SchedulerConfig: """ name: str - kwargs: Dict[str, Any] = field(default_factory=dict) + kwargs: Dict[str, Any] = field(factory=dict) @classmethod def from_dict(cls, config: Dict[str, Any]): return cls(**config) -@dataclass +@define class TrainConfig: """ Config for train job on model. @@ -193,7 +193,7 @@ class TrainConfig: pipeline: str # One of the pipelines in framework.pipeline trainer: str # One of the trainers - trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer + trainer_kwargs: Dict[str, Any] = field(factory=dict) # Extra keyword arguments for the trainer project_name: str = "trlx" entity_name: Optional[str] = None @@ -213,7 +213,7 @@ def from_dict(cls, config: Dict[str, Any]): return cls(**config) -@dataclass +@define class TRLConfig: """ Top level config for trlX. Loads configs and can be converted to dictionary. From e5fb2387cd227cba46080e46261d3b4b4fb597e3 Mon Sep 17 00:00:00 2001 From: cat-state Date: Thu, 23 Feb 2023 19:25:46 +0000 Subject: [PATCH 13/13] attrs fix --- trlx/data/configs.py | 41 ++++---------------------- trlx/data/default_configs.py | 8 ++--- trlx/data/method_configs.py | 28 +++++++----------- trlx/models/modeling_ilql.py | 4 +-- trlx/models/modeling_ppo.py | 3 +- trlx/trainer/accelerate_sft_trainer.py | 5 ++-- 6 files changed, 25 insertions(+), 64 deletions(-) diff --git a/trlx/data/configs.py b/trlx/data/configs.py index 74da82902..6f5d8d346 100644 --- a/trlx/data/configs.py +++ b/trlx/data/configs.py @@ -1,8 +1,9 @@ from copy import deepcopy -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import yaml from attrs import define, field +from cattrs import structure, unstructure from trlx.data.method_configs import MethodConfig, get_method @@ -80,10 +81,6 @@ class TokenizerConfig: padding_side: str = "left" truncation_side: str = "right" - @classmethod - def from_dict(cls, config: Dict[str, Any]): - return cls(**config) - @define class OptimizerConfig: @@ -100,10 +97,6 @@ class OptimizerConfig: name: str kwargs: Dict[str, Any] = field(factory=dict) - @classmethod - def from_dict(cls, config: Dict[str, Any]): - return cls(**config) - @define class SchedulerConfig: @@ -120,10 +113,6 @@ class SchedulerConfig: name: str kwargs: Dict[str, Any] = field(factory=dict) - @classmethod - def from_dict(cls, config: Dict[str, Any]): - return cls(**config) - @define class TrainConfig: @@ -208,10 +197,6 @@ class TrainConfig: seed: int = 1000 - @classmethod - def from_dict(cls, config: Dict[str, Any]): - return cls(**config) - @define class TRLConfig: @@ -242,16 +227,7 @@ def to_dict(self): """ Convert TRLConfig to dictionary. """ - data = { - "method": self.method.__dict__, - "model": self.model.__dict__, - "optimizer": self.optimizer.__dict__, - "scheduler": self.scheduler.__dict__, - "tokenizer": self.tokenizer.__dict__, - "train": self.train.__dict__, - } - - return data + return unstructure(self) def evolve(self, **kwargs) -> "TRLConfig": """ @@ -261,21 +237,14 @@ def evolve(self, **kwargs) -> "TRLConfig": >>> config.method.gamma 0.99 """ - return TRLConfig.update(self.to_dict(), kwargs) + return TRLConfig.from_dict(merge(self.to_dict(), kwargs)) @classmethod def from_dict(cls, config: Dict): """ Convert dictionary to TRLConfig. """ - return cls( - method=get_method(config["method"]["name"]).from_dict(config["method"]), - model=ModelConfig.from_dict(config["model"]), - tokenizer=TokenizerConfig.from_dict(config["tokenizer"]), - optimizer=OptimizerConfig.from_dict(config["optimizer"]), - scheduler=SchedulerConfig.from_dict(config["scheduler"]), - train=TrainConfig.from_dict(config["train"]), - ) + return structure(config, cls) @classmethod def update(cls, baseconfig: Dict, config: Dict): diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 3f1196b11..4216a98aa 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -1,7 +1,4 @@ -from ..trainer.accelerate_sft_trainer import SFTConfig -from ..trainer.nn.ilql_models import ILQLConfig -from ..trainer.nn.ppo_models import PPOConfig -from .configs import ( +from trlx.data.configs import ( ModelConfig, OptimizerConfig, SchedulerConfig, @@ -9,6 +6,9 @@ TrainConfig, TRLConfig, ) +from trlx.models.modeling_ilql import ILQLConfig +from trlx.models.modeling_ppo import PPOConfig +from trlx.trainer.accelerate_sft_trainer import SFTConfig def default_ppo_config(): diff --git a/trlx/data/method_configs.py b/trlx/data/method_configs.py index 435ce13cf..40a74bb93 100644 --- a/trlx/data/method_configs.py +++ b/trlx/data/method_configs.py @@ -1,34 +1,22 @@ -import sys -from dataclasses import dataclass from typing import Any, Dict +from attrs import asdict, define +from cattrs import register_structure_hook, register_unstructure_hook, unstructure + # specifies a dictionary of method configs _METHODS: Dict[str, Any] = {} # registry -def register_method(name): +def register_method(cls): """Decorator used register a method config Args: name: Name of the method """ - - def register_class(cls, name): - _METHODS[name] = cls - setattr(sys.modules[__name__], name, cls) - return cls - - if isinstance(name, str): - name = name.lower() - return lambda c: register_class(c, name) - - cls = name - name = cls.__name__ - register_class(cls, name.lower()) - + _METHODS[cls.__name__.lower()] = cls return cls -@dataclass +@define @register_method class MethodConfig: """ @@ -54,3 +42,7 @@ def get_method(name: str) -> MethodConfig: return _METHODS[name] else: raise Exception("Error: Trying to access a method that has not been registered") + + +register_structure_hook(MethodConfig, lambda obj, _: get_method(obj["name"])(**obj)) +register_unstructure_hook(MethodConfig, lambda obj: {**asdict(obj), "name": obj.__class__.__name__}) diff --git a/trlx/models/modeling_ilql.py b/trlx/models/modeling_ilql.py index 5eab697e4..3e9bb9976 100644 --- a/trlx/models/modeling_ilql.py +++ b/trlx/models/modeling_ilql.py @@ -1,7 +1,6 @@ import gc import os from copy import deepcopy -from dataclasses import dataclass from functools import reduce from itertools import chain @@ -10,6 +9,7 @@ import torch import torch.nn.functional as F import transformers +from attrs import define from torch import nn from torchtyping import TensorType @@ -44,7 +44,7 @@ def batched_index_select( return x.gather(dim=dim, index=idxs) -@dataclass +@define @register_method class ILQLConfig(MethodConfig): tau: float diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index 787286123..f72101e8b 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import transformers +from attrs import define from torchtyping import TensorType from transformers.modeling_outputs import ModelOutput from transformers.models.bloom import modeling_bloom @@ -69,7 +70,7 @@ def update(self, current: float, n_steps: int): # PPO Configs -@dataclass +@define @register_method class PPOConfig(MethodConfig): """ diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index e061896e6..01e0b6929 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass - +from attrs import define from transformers import AutoModelForCausalLM from trlx.data.configs import TRLConfig @@ -8,7 +7,7 @@ from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer -@dataclass +@define @register_method class SFTConfig(MethodConfig): """