57 changes: 42 additions & 15 deletions benchmarks/continual_eval_checkpoints.py
@@ -1,9 +1,9 @@
"""Evaluating checkpoints obtained from training using the dpo_continual script."""

import glob
import os
import re

import torch
import wandb as wb
from dataloading import init_continual_dataset
from datasets import Dataset
from dpo.continual_dpo_trainer import (
@@ -17,22 +17,18 @@
AutoTokenizer,
)
from trl import (
DPOConfig,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

import wandb as wb


def main(
script_args: ScriptArguments,
training_args: DPOConfig,
script_args: ContinualDPOArguments,
training_args: ContinualDPOConfig,
model_args: ModelConfig,
) -> None:
# Determine torch dtype and quantization configs
@@ -41,6 +37,9 @@ def main(
if model_args.torch_dtype in ['auto', None]
else getattr(torch, model_args.torch_dtype)
)
if script_args.wandb_run_name is not None:
training_args.run_name = script_args.wandb_run_name

quantization_config = get_quantization_config(model_args)

# Model & Tokenizer Setup
@@ -87,14 +86,26 @@ def main(

# Validate reward model paths if provided
for i, _ in enumerate(continual_dataset):
reward_path = os.path.join(training_args.reward_model_path, str(i))
reward_path = training_args.reward_model_path + '_' + str(i)
if not os.path.exists(reward_path):
raise FileNotFoundError(
f'Reward model not found for dataset {i} at {reward_path}'
)

checkpoint_paths = glob.glob(f'{script_args.checkpoint_dir}/*/*')
checkpoint_paths = sorted([ch for ch in checkpoint_paths if 'checkpoint' in ch])

def extract_indices(path):
match = re.search(r'dataset-(\d+)/checkpoint-(\d+)', path)
if match:
dataset_idx = int(match.group(1))
checkpoint_idx = int(match.group(2))
return (dataset_idx, checkpoint_idx)
else:
return (float('inf'), float('inf')) # in case of unexpected format

checkpoint_paths = [ch for ch in checkpoint_paths if 'checkpoint' in ch]
checkpoint_paths.sort(key=extract_indices)
print('checkpoint_paths', checkpoint_paths)
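For context, the plain `sorted()` call it replaces orders paths lexicographically, which puts `checkpoint-1000` ahead of `checkpoint-200`; the regex key sorts numerically by dataset index and then checkpoint step. A minimal standalone sketch of that behaviour (the example paths are made up):

```python
# Standalone sketch of the numeric ordering provided by extract_indices above.
# The example paths are hypothetical; only the dataset-<i>/checkpoint-<step>
# naming pattern is assumed to match the training output layout.
import re


def extract_indices(path: str) -> tuple:
    match = re.search(r'dataset-(\d+)/checkpoint-(\d+)', path)
    if match:
        return (int(match.group(1)), int(match.group(2)))
    return (float('inf'), float('inf'))  # unexpected names sort last


paths = [
    'runs/dataset-0/checkpoint-1000',
    'runs/dataset-0/checkpoint-200',
    'runs/dataset-1/checkpoint-50',
]

print(sorted(paths))                       # lexicographic: checkpoint-1000 before checkpoint-200
print(sorted(paths, key=extract_indices))  # numeric: checkpoint-200 before checkpoint-1000
```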

# Checkpoint loop
for checkpoint_path in checkpoint_paths:
@@ -103,14 +114,20 @@ def main(
print(
f'Evaluating checkpoint: {checkpoint_step} trained on dataset: {dataset_name} on all tasks'
)
adapter_name = dataset_name + checkpoint_step
model.load_adapter(checkpoint_path, adapter_name=adapter_name)
# adapter_name = dataset_name + checkpoint_step
# model.load_adapter(checkpoint_path, adapter_name=adapter_name)
model = AutoModelForCausalLM.from_pretrained(
checkpoint_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
metrics = {}

# Task Loop
for i, dataset in enumerate(continual_dataset):
print('task', i)
reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path + f'/{str(i)}', num_labels=1
training_args.reward_model_path + f'_{str(i)}', num_labels=1
)

training_args.output_dir = f'{output_dir}/dataset-{i}'
@@ -129,8 +146,18 @@ def main(
ev_metrics = trainer.evaluate()
ev_metrics = {f'dataset-{i}/' + k: v for k, v in ev_metrics.items()}
metrics.update(ev_metrics)

wb.log(metrics) # type: ignore[attr-defined]
if training_args.local_rank in (None, -1, 0):
wb.log({f'task/{dataset_name}/{k}': v for k, v in ev_metrics.items()})

# If using DeepSpeed through Accelerate, tear down the engine after training.
if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None:
# Remove reference to the DeepSpeed engine to allow proper cleanup.
del trainer.deepspeed
# Free cached GPU memory.
torch.cuda.empty_cache()

if training_args.local_rank in (None, -1, 0):
wb.log(metrics) # type: ignore[attr-defined]

print('Evaluation completed for all tasks and checkpoints!')
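For reference, the reward-model lookups in this script now expect one directory per task, suffixed with the task index, instead of numbered subdirectories under a common parent. A small sketch of the assumed layout (the base path is hypothetical):

```python
# Assumed reward-model layout after this change (base path is hypothetical):
#   /models/qwen-reward_0   <- task 0
#   /models/qwen-reward_1   <- task 1
# Previously the script looked for /models/qwen-reward/0, /models/qwen-reward/1, ...
import os

reward_model_path = '/models/qwen-reward'  # hypothetical reward_model_path value
num_tasks = 2

for i in range(num_tasks):
    reward_path = f'{reward_model_path}_{i}'
    if not os.path.exists(reward_path):
        raise FileNotFoundError(f'Reward model not found for dataset {i} at {reward_path}')
```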

5 changes: 4 additions & 1 deletion benchmarks/dataloading.py
@@ -89,8 +89,11 @@ def init_continual_dataset(
data = ContinualAlignmentDataset.from_json(dataset)
except OSError: # need to try downloading from hub
try:
# print(f'Downloading {json_name} from Hugging Face Hub...')
local_path = hf_hub_download(
repo_id=dataset, filename='dataset.json', repo_type='dataset'
repo_id=f'LifelongAlignment/{dataset}',
filename='data.json',
repo_type='dataset',
)
data = ContinualAlignmentDataset.from_json(local_path)
except Exception as e:
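For reference, the fallback above resolves the dataset name to a `LifelongAlignment/<name>` dataset repo on the Hub and fetches its `data.json`. A standalone sketch of just that call (the dataset name is made up):

```python
# Standalone sketch of the Hub fallback above; the dataset name is hypothetical,
# only the LifelongAlignment/<name> repo + data.json convention comes from the diff.
from huggingface_hub import hf_hub_download

local_path = hf_hub_download(
    repo_id='LifelongAlignment/example-continual-benchmark',  # hypothetical name
    filename='data.json',
    repo_type='dataset',
)
print(local_path)  # local cache path of the downloaded JSON
```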
21 changes: 21 additions & 0 deletions benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
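This new ZeRO-2 config (2 processes, bf16, no optimizer or parameter offloading) would presumably be passed to `accelerate launch` via `--config_file`. Below is a hedged sketch of such an invocation from Python; the script path and the subprocess approach are assumptions based on this PR's layout, not a documented entry point:

```python
# Sketch only: launching the continual DPO training script under the new ZeRO-2 config.
# Adjust num_processes in the YAML to match the number of available GPUs.
import subprocess

subprocess.run(
    [
        'accelerate', 'launch',
        '--config_file', 'benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml',
        'benchmarks/dpo/dpo_continual.py',
        # ...model/dataset arguments for the script would follow here...
    ],
    check=True,
)
```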
2 changes: 1 addition & 1 deletion benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml
@@ -11,7 +11,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1 # TODO change to whatever number of gpus is used
num_processes: 8 # TODO change to whatever number of gpus is used
rdzv_backend: static
same_network: true
tpu_env: []
124 changes: 79 additions & 45 deletions benchmarks/dpo/continual_dpo_trainer.py
@@ -9,9 +9,12 @@
import pandas as pd
import torch
import torch.nn as nn
import wandb as wb
from accelerate import Accelerator, PartialState
from accelerate.utils import gather_object
from datasets import Dataset
from rich.console import Console
from rich.table import Table
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
@@ -36,8 +39,6 @@
)
from typing_extensions import override

import wandb as wb


@dataclass
class ContinualDPOArguments(ScriptArguments):
@@ -285,7 +286,10 @@ def evaluate_policy(self) -> dict:

with torch.no_grad():
if self.eval_policy_dataloader is not None:
for batch in self.eval_policy_dataloader:
for idx, batch in enumerate(self.eval_policy_dataloader):
print(
f'Processing batch {idx} out of {len(self.eval_policy_dataloader)}'
)
query = batch['input_ids'].to(self.accelerator.device)
context_length = query.shape[1]
with unwrap_model_for_generation(
@@ -324,71 +328,101 @@ def log(
train_eval = 'train' if 'loss' in logs else 'eval'
print(f'Logging {train_eval} metrics...')
if train_eval == 'eval':
print('Computing policy metrics...')
eval_policy_metrics = self.evaluate_policy()
logs.update(eval_policy_metrics)
if self.reward_model is not None:
print('Computing policy metrics...')
eval_policy_metrics = self.evaluate_policy()
logs.update(eval_policy_metrics)

            # TODO: Only generate sample completions every x steps
do_generate_completions = True
if do_generate_completions:
print('Generating completions...')
self._generate_completions()
torch.cuda.empty_cache()

return super().log(logs, start_time)
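The `do_generate_completions = True` flag above is a placeholder for the TODO; one possible way to gate generation to every N evaluations is sketched below. Note that `completion_log_steps` is not a field of `ContinualDPOConfig` in this PR and `self.state.global_step` comes from the base `Trainer`; this is an illustration, not part of the change.

```python
# Hypothetical helper for the TODO above ("only generate sample completions every x steps").
from typing import Optional


def should_generate_completions(global_step: int, completion_log_steps: Optional[int]) -> bool:
    """Return True when sample completions should be generated at this step."""
    if completion_log_steps is None:  # keep the current behaviour: always generate
        return True
    return global_step % completion_log_steps == 0


# Inside ContinualDPOTrainer.log(), this could replace the hard-coded flag:
#     do_generate_completions = should_generate_completions(
#         self.state.global_step, getattr(self.args, 'completion_log_steps', None)
#     )
```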

def _generate_completions(self) -> None:
# Config from: https://github.com/huggingface/trl/blob/56e57662053e2d0cc6302dad404820b0c0ec6a91/trl/trainer/ppo_trainer.py#L688
# generation_config = GenerationConfig(
# max_new_tokens=53,
# temperature=(0.01 + 1e-7),
# top_k=0.0,
# top_p=1.0,
# do_sample=True,
# )
generation_config = GenerationConfig(
max_new_tokens=53,
temperature=(0.01 + 1e-7),
max_new_tokens=self.args.response_length,
temperature=(self.args.temperature + 1e-7),
top_k=0.0,
top_p=1.0,
do_sample=True,
)

self.model.eval()
table = defaultdict(list)
with torch.no_grad():
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=None,
) as unwrapped_model:
for batch in self.eval_dataloader:
query = batch['input_ids']
context_length = query.shape[1]
query_response, _ = batch_generation(
unwrapped_model,
query,
query.shape[0],
self.processing_class.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
postprocessed_query_response = torch.cat(
(query, postprocessed_response), 1
)
_, score, _ = get_reward(
self.reward_model,
postprocessed_query_response,
self.processing_class.pad_token_id,
context_length,
)
if self.eval_policy_dataloader is not None:
for batch in self.eval_policy_dataloader:
query = batch['input_ids']
context_length = query.shape[1]
query_response, _ = batch_generation(
unwrapped_model,
query,
query.shape[0],
self.processing_class.pad_token_id,
generation_config,
)
response = query_response[:, context_length:]
postprocessed_response = response
postprocessed_query_response = torch.cat(
(query, postprocessed_response), 1
)
_, score, _ = get_reward(
self.reward_model,
postprocessed_query_response,
self.processing_class.pad_token_id,
context_length,
)

queries = gather_object(
self.processing_class.batch_decode(
query, skip_special_tokens=True
queries = gather_object(
self.processing_class.batch_decode(
query, skip_special_tokens=True
)
)
)
responses = gather_object(
self.processing_class.batch_decode(postprocessed_response)
)
scores = (
self.accelerator.gather_for_metrics(score).float().cpu().numpy()
)
table['query'].extend(queries)
table['model response'].extend(responses)
table['score'].extend(scores)
break
responses = gather_object(
self.processing_class.batch_decode(postprocessed_response)
)
scores = (
self.accelerator.gather_for_metrics(score)
.float()
.cpu()
.numpy()
)
table['query'].extend(queries)
table['model response'].extend(responses)
table['score'].extend(scores)
break

self.model.train()
df = pd.DataFrame(table)
if self.accelerator.is_main_process and wb.run is not None:
wb.log({'completions': wb.Table(dataframe=df)})

        if self.accelerator is None or self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if wb.run is not None:
wb.log({'completions': wb.Table(dataframe=df)})


def print_rich_table(df: pd.DataFrame) -> Table:
    console = Console()
    table = Table(show_lines=True)
    for column in df.columns:
        table.add_column(column)
    for _, row in df.iterrows():
        table.add_row(*row.astype(str).tolist())
    console.print(table)
    return table
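A quick usage sketch for the helper above, with a made-up row that mirrors the columns collected in `_generate_completions`:

```python
import pandas as pd

# Hypothetical single-row completions table; the columns match the `table` dict
# built in _generate_completions.
df = pd.DataFrame(
    {
        'query': ['What is direct preference optimization?'],
        'model response': ['DPO fine-tunes the policy directly on preference pairs ...'],
        'score': [0.42],
    }
)
print_rich_table(df.head(5))  # renders up to the first five rows as a rich table
```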
17 changes: 8 additions & 9 deletions benchmarks/dpo/dpo_continual.py
@@ -3,11 +3,7 @@
import os

import torch
from continual_dpo_trainer import (
ContinualDPOArguments,
ContinualDPOConfig,
ContinualDPOTrainer,
)
import wandb as wb
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
@@ -23,7 +19,6 @@
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

import wandb as wb
from benchmarks.dataloading import init_continual_dataset
from benchmarks.dpo.continual_dpo_trainer import (
ContinualDPOArguments,
@@ -104,7 +99,7 @@ def main(
# first check the hub if the model is present
try:
AutoModelForSequenceClassification.from_pretrained(
reward_path, num_labels=1
reward_path, num_labels=1, use_cache=True
)
except:
# if not found in the hub, check the local path
@@ -137,6 +132,9 @@ def main(
peft_config=peft_config,
)

# if i == 0:
# trainer.save_model(os.path.join(training_args.output_dir, 'checkpoint-0'))

# TODO will throw Invalidate trace cache @ step 10: expected module 11, but got module 19
# https://github.com/deepspeedai/DeepSpeed/issues/6870
# Fix with deepspeed fix release
@@ -152,8 +150,9 @@ def main(
print(f'eval/dataset/{i}')
trainer.log_metrics(f'eval/dataset/{i}', metrics)
trainer.save_metrics(f'eval', metrics)
wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined]
wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined]
if training_args.local_rank in (None, -1, 0):
wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined]
wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined]

# Save and push to hub
trainer.save_model(os.path.join(training_args.output_dir, 'last'))