Skip to content
Draft
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ trainer = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0
trainer.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
```

#### Configure Hyperparameters

```python
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.model.model_path = 'EleutherAI/gpt-neox-20b'
config.train.seq_length = 32
config.train.batch_size = 16

trainer = trlx.train(config=config, reward_fn=lambda samples, **kwargs: [float(int(sample)) for sample in samples])
```

#### Save the resulting model as a Hugging Face pretrained language model (ready to upload to the Hub!)

```python
Expand Down
50 changes: 0 additions & 50 deletions configs/ilql_config.yml

This file was deleted.

52 changes: 0 additions & 52 deletions configs/nemo_ilql_config.yml

This file was deleted.

56 changes: 0 additions & 56 deletions configs/ppo_config.yml

This file was deleted.

56 changes: 0 additions & 56 deletions configs/ppo_gptj.yml

This file was deleted.

41 changes: 0 additions & 41 deletions configs/sft_config.yml

This file was deleted.

14 changes: 3 additions & 11 deletions examples/architext.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
# Toy example of optimizing textual interior designs to output the least number of rooms
# Also see https://architext.design/
import pathlib

import yaml

import trlx
from trlx.data.configs import TRLConfig
from trlx.data.default_configs import default_ppo_config


def reward_fn(samples, **kwargs):
Expand All @@ -30,13 +26,9 @@ def reward_fn(samples, **kwargs):
"[prompt] the kitchen is not adjacent to the bathroom [layout]",
]

config_path = pathlib.Path(__file__).parent.joinpath("../configs/ppo_config.yml")
with config_path.open() as f:
default_config = yaml.safe_load(f)


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)
def main():
config = default_ppo_config()

trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config)

Expand Down
13 changes: 3 additions & 10 deletions examples/ilql_sentiments.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
import os
import pathlib
from typing import Dict, List

import yaml
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.configs import TRLConfig
from trlx.data.default_configs import default_ilql_config


def get_positive_score(scores):
    """Return the score attached to the "POSITIVE" label in a sentiment pipeline's output."""
    # Each entry looks like {"label": ..., "score": ...}; index by label, then pick POSITIVE.
    score_by_label = {entry["label"]: entry["score"] for entry in scores}
    return score_by_label["POSITIVE"]


config_path = pathlib.Path(__file__).parent.joinpath("../configs/ilql_config.yml")
with config_path.open() as f:
default_config = yaml.safe_load(f)


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)
def main():
config = default_ilql_config()

sentiment_fn = pipeline(
"sentiment-analysis",
Expand Down
20 changes: 16 additions & 4 deletions examples/nemo_ilql_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,31 @@
import sys
from glob import glob

import yaml
from nemo.collections.nlp.modules.common.megatron.megatron_init import (
fake_initialize_model_parallel,
)
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank
from omegaconf.omegaconf import OmegaConf

from trlx.data.configs import TRLConfig
from trlx.data.configs import TrainConfig
from trlx.data.default_configs import default_ilql_config
from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer

default_config = yaml.safe_load(open(os.path.dirname(__file__) + "/../configs/nemo_ilql_config.yml"))
# Start from the library's default ILQL config, then override only the
# `train` section to run under the NeMo trainer with a local 20B
# Megatron checkpoint.
default_config = default_ilql_config()

trl_config = default_config.evolve(
    train=TrainConfig(
        # NOTE(review): rebuilds TrainConfig from the existing fields with two
        # entries overridden — presumably TrainConfig has no partial-update
        # helper, so the whole __dict__ is splatted back in; confirm.
        **dict(
            default_config.train.__dict__,
            trainer="NeMoILQLTrainer",
            trainer_kwargs=dict(
                pretrained_model="/mnt/nvme/home/uwu/nemo-megatron-gpt-20B/",
                megatron_cfg="megatron_20b.yaml",
            ),
        ),
    )
)


def find_checkpoints(checkpoint_dir):
Expand All @@ -23,7 +36,6 @@ def find_checkpoints(checkpoint_dir):


def main(megatron_cfg_path, checkpoint_path):
trl_config = TRLConfig.update(default_config, {})
ilql_config = trl_config.method

megatron_cfg = OmegaConf.load(megatron_cfg_path)
Expand Down
Loading