6 changes: 3 additions & 3 deletions benchmarks/cppo/README.md
@@ -16,8 +16,8 @@ uv sync --group benchmarks
uv run benchmarks/cppo/cppo.py \
--dataset_name benchmarks/continual_data_debug.json \
--sft_model_path Qwen/Qwen2-0.5B-Instruct \
--value_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD_0 \
--reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \
--value_model_path LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD_0 \
--reward_model_path LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD \
--learning_rate 5.0e-6 \
--num_train_epochs 1 \
--gradient_accumulation_steps 8 \
@@ -31,7 +31,7 @@ uv run benchmarks/cppo/cppo.py \
--no_remove_unused_columns \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--lora_alpha 16 \
--push_to_hub True
```

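A note on the two Hub arguments in the command above: cppo.py appends the task index to `--reward_model_path` when it loads each task's reward model, while `--value_model_path` is used as given for the first task (hence the trailing `_0`). A minimal sketch of that lookup, assuming the task-0 repositories are available on the Hub:

```python
# Sketch of how the benchmark resolves per-task reward models (see cppo.py below):
# task i loads `<reward_model_path>_<i>`; the value model path already names task 0.
from transformers import AutoModelForSequenceClassification

reward_model_path = 'LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD'
value_model_path = 'LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD_0'

task_index = 0  # first task in the continual sequence
reward_model = AutoModelForSequenceClassification.from_pretrained(
    f'{reward_model_path}_{task_index}', num_labels=1
)
value_model = AutoModelForSequenceClassification.from_pretrained(
    value_model_path, num_labels=1
)
```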
2 changes: 1 addition & 1 deletion benchmarks/cppo/accelerate_configs/deepspeed_zero2.yaml
@@ -12,7 +12,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
90 changes: 65 additions & 25 deletions benchmarks/cppo/cppo.py
@@ -17,7 +17,7 @@
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from trl import setup_chat_format

import wandb as wb
from benchmarks.dataloading import init_continual_dataset
@@ -50,11 +50,6 @@ def main(

# Load main model and (optionally) reference model
model = str(training_args.sft_model_path)
policy = AutoModelForCausalLM.from_pretrained(
training_args.sft_model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
peft_config = get_peft_config(model_args)
if peft_config is None:
ref_policy = AutoModelForCausalLM.from_pretrained(
@@ -65,22 +60,11 @@
else:
ref_policy = None

# Load value model and policy model (main model)
value_model = AutoModelForSequenceClassification.from_pretrained(
script_args.value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)

# Load tokenizer and set chat template if needed
tokenizer = AutoTokenizer.from_pretrained(
training_args.sft_model_path,
trust_remote_code=model_args.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

# Initialize continual dataset
continual_dataset: list[dict[str, Dataset]] = init_continual_dataset(
@@ -114,6 +98,34 @@
old_logprobs, old_rewards = None, None

for i, dataset in enumerate(continual_dataset):
# Load the policy for this task: the SFT model for task 0, the previous task's checkpoint afterwards
if i == 0:
model_path = training_args.sft_model_path
value_model_path = script_args.value_model_path
else:
model_path = os.path.join(training_args.output_dir, 'last')
value_model_path = os.path.join(training_args.output_dir, 'last', 'value_model')
policy = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=model_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)

# Load the value model for this task
try:
value_model = AutoModelForSequenceClassification.from_pretrained(
value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
)
except OSError:
# Retry, forcing the PyTorch .bin weights; the value model is saved below
# with `safe_serialization=False`, so no safetensors file may be present
value_model = AutoModelForSequenceClassification.from_pretrained(
value_model_path,
trust_remote_code=model_args.trust_remote_code,
num_labels=1,
use_safetensors=False,
)
# Build custom repository name for this task
custom_repo_name = (
model.split('/')[-1] + '_' + clean_dataset_name + '_CPPO_' + str(i)
@@ -127,6 +139,22 @@
training_args.reward_model_path + '_' + str(i), num_labels=1
)

# Make sure the tokenizer has a pad token before aligning it with the models
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

for idx, _model in enumerate([policy, value_model, reward_model]):
# Align padding tokens between tokenizer and model
_model.config.pad_token_id = tokenizer.pad_token_id

# Use ChatML format if the tokenizer doesn't already have a chat template
if tokenizer.chat_template is None:
updated_model, updated_tokenizer = setup_chat_format(_model, tokenizer)
# Actually store the updated model
if idx == 0:
policy = updated_model
elif idx == 1:
value_model = updated_model
else:
reward_model = updated_model
tokenizer = updated_tokenizer

################
# Training and Evaluation
################
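For context on the `setup_chat_format` loop a few lines above: trl's `setup_chat_format` returns a new (model, tokenizer) pair because it installs a ChatML template, adds its special tokens, and resizes the embeddings, which is why the loop reassigns `policy`, `value_model`, and `reward_model`. A minimal sketch with a model whose tokenizer lacks a chat template (GPT-2 here is purely an illustrative assumption):

```python
# Illustrative only: GPT-2 ships without a chat template, so setup_chat_format
# applies ChatML, adds the special tokens, and resizes the model's embeddings.
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)

print(tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'hello'}], tokenize=False
))
```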
@@ -163,21 +191,33 @@ def main(
trainer.log_metrics(f'eval/dataset/{i}', metrics)
trainer.save_metrics('eval', metrics)

# Log metrics to WandB
wb.log({'eval': {'last': metrics}})
wb.log({f'task/{custom_repo_name}/last': metrics})
if training_args.local_rank in (None, -1, 0):
# Log metrics to WandB
wb.log({'eval': {'dataset': i, 'last': metrics}})
wb.log({f'task/{custom_repo_name}/dataset/{i}': metrics})

# Save model checkpoint and optionally push
if not training_args.push_to_hub:
trainer.save_model(os.path.join(training_args.output_dir, 'last'))
else:
last_dir = os.path.join(training_args.output_dir, 'last')
policy.save_pretrained(last_dir)
tokenizer.save_pretrained(last_dir)

value_model_dir = os.path.join(last_dir, 'value_model')
os.makedirs(value_model_dir, exist_ok=True)
value_model.save_pretrained(value_model_dir,
safe_serialization=False)

trainer.accelerator.wait_for_everyone()

if training_args.push_to_hub:
trainer.push_to_hub(
model_name=custom_repo_name,
dataset_name='CPPO_' + clean_dataset_name + '_' + str(i),
dataset_name='Continual_CPPO_' + clean_dataset_name + '_' + str(i),
)

ref_policy = None
old_logprobs, old_rewards = trainer.old_logprobs, trainer.old_rewards
if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None:
del trainer.deepspeed
torch.cuda.empty_cache()

print('Training completed for all tasks!')

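Taken together, the cppo.py changes above replace the single up-front model load with a per-task save/reload cycle: task 0 starts from the SFT and value checkpoints, and every later task resumes from `<output_dir>/last` (policy and tokenizer) and `<output_dir>/last/value_model`. A minimal standalone sketch of that cycle (the output directory and task count are assumptions, not the benchmark's defaults):

```python
# Minimal sketch (not the benchmark itself) of the per-task checkpoint cycle in cppo.py:
# task 0 starts from the SFT model; every later task resumes from `<output_dir>/last`
# and `<output_dir>/last/value_model`, which the previous task saved.
import os

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

OUTPUT_DIR = 'outputs/cppo_demo'  # assumed; stands in for training_args.output_dir
SFT_MODEL = 'Qwen/Qwen2-0.5B-Instruct'
VALUE_MODEL = 'LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD_0'
NUM_TASKS = 2  # stands in for len(continual_dataset)

tokenizer = AutoTokenizer.from_pretrained(SFT_MODEL)

for task_idx in range(NUM_TASKS):
    if task_idx == 0:
        model_path, value_path = SFT_MODEL, VALUE_MODEL
    else:
        model_path = os.path.join(OUTPUT_DIR, 'last')
        value_path = os.path.join(OUTPUT_DIR, 'last', 'value_model')

    policy = AutoModelForCausalLM.from_pretrained(model_path)
    value_model = AutoModelForSequenceClassification.from_pretrained(value_path, num_labels=1)

    # ... PPO training on task `task_idx` would happen here ...

    last_dir = os.path.join(OUTPUT_DIR, 'last')
    policy.save_pretrained(last_dir)
    tokenizer.save_pretrained(last_dir)
    value_model.save_pretrained(
        os.path.join(last_dir, 'value_model'), safe_serialization=False
    )
```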
111 changes: 55 additions & 56 deletions benchmarks/cppo/cppo_trainer.py
@@ -118,13 +118,6 @@


class CPPOTrainer(PPOTrainer):
# Shared accelerator instance across all trainer instances
shared_accelerator: Optional[Accelerator] = None
current_task_index: Optional[int] = None
policy_value_models: Any # the policy and value model wrapper
ds_wrapped_models: Any # TODO work with this after deepspeed is initialized
accelerator: Accelerator # now non-optional after creation
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None

def __init__(
self,
@@ -153,6 +146,14 @@
old_logprobs: Optional[Tensor] = None,
old_rewards: Optional[Tensor] = None,
):
self.shared_accelerator: Optional[Accelerator] = None
self.current_task_index: Optional[int] = None
self.policy_value_models: Any = None # the policy and value model wrapper
self.ds_wrapped_models: Any = None # TODO work with this after deepspeed is initialized
self.accelerator: Optional[Accelerator] = None # created (or reused from shared_accelerator) below
self.ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None

# Basic setup and validation
if args is None:
raise ValueError('`args` cannot be None')
@@ -175,18 +176,18 @@
# Initialize task tracking
self._stored_metrics: Dict = defaultdict(lambda: defaultdict(list))
self.current_task = (
f'task_{CPPOTrainer.current_task_index}'
if CPPOTrainer.current_task_index is not None
f'task_{self.current_task_index}'
if self.current_task_index is not None
else 'task_0'
)

# Set up task index tracking
is_first_task = False
if CPPOTrainer.current_task_index is None:
CPPOTrainer.current_task_index = 0
if self.current_task_index is None:
self.current_task_index = 0
is_first_task = True
else:
CPPOTrainer.current_task_index += 1
self.current_task_index += 1
self.is_final_eval = False

# Store basic configuration
@@ -247,7 +248,7 @@
else:
self.ref_model = create_reference_model(self.policy_model)

CPPOTrainer.class_ref_model = self.ref_model
self.class_ref_model = self.ref_model

else:
# For subsequent tasks, reuse the reference model
@@ -284,14 +285,14 @@
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)

# Setup accelerator - shared across all tasks
if CPPOTrainer.shared_accelerator is None:
if self.shared_accelerator is None:
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps
)
self.accelerator = accelerator
CPPOTrainer.shared_accelerator = accelerator
self.shared_accelerator = accelerator
else:
self.accelerator = CPPOTrainer.shared_accelerator
self.accelerator = self.shared_accelerator
self.gather_function = self.accelerator.gather_for_metrics
if (
'use_gather_object'
@@ -331,7 +332,7 @@
args.num_total_batches = math.ceil(args.total_episodes / args.batch_size)
time_tensor = torch.tensor(int(time.time()), device=self.accelerator.device)
time_int = broadcast(time_tensor, 0).item()
args.run_name = f'{args.exp_name}__{args.seed}__{time_int}'
# args.run_name = f'{args.exp_name}__{args.seed}__{time_int}'
self.local_seed = args.seed + self.accelerator.process_index * 100003 # Prime
if args.num_sample_generations > 0:
self.sample_generations_freq = max(
@@ -353,11 +354,12 @@

# Create policy and value model wrapper
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
CPPOTrainer.policy_value_models = self.model
self.policy_value_models = self.model
self.model.config = self.policy_model.config # needed for pushing to hub
else:
disable_dropout_in_model(self.reward_model)
# Subsequent tasks: Reuse existing model
self.model = CPPOTrainer.policy_value_models
self.model = self.policy_value_models
self.model.config = self.policy_model.config # needed for pushing to hub

# Always create optimizer and scheduler for each task
@@ -425,14 +427,14 @@
self.model, self.optimizer, self.dataloader = self.accelerator.prepare(
self.model, self.optimizer, self.dataloader
)
CPPOTrainer.ds_wrapped_models = self.model
self.ds_wrapped_models = self.model
else:
# For subsequent tasks, only prepare optimizer and dataloader
self.optimizer, self.dataloader = self.accelerator.prepare(
self.optimizer, self.dataloader
)
# Reuse the model from the first task
self.model = CPPOTrainer.ds_wrapped_models
self.model = self.ds_wrapped_models

torch.manual_seed(self.local_seed) # Reset local seed

@@ -469,10 +471,10 @@
args.fp16,
args.bf16,
)
CPPOTrainer.class_ref_model = self.ref_model
self.class_ref_model = self.ref_model
else:
# Reuse prepared ref_model on subsequent tasks
self.ref_model = CPPOTrainer.class_ref_model
self.ref_model = self.class_ref_model
else:
# Non-DeepSpeed path
if self.ref_model is None:
@@ -483,10 +485,10 @@
elif is_first_task:
# Only move ref_model to device on first task
self.ref_model = self.ref_model.to(self.accelerator.device) # type: ignore
CPPOTrainer.class_ref_model = self.ref_model
self.class_ref_model = self.ref_model
else:
# Reuse ref_model on subsequent tasks
self.ref_model = CPPOTrainer.class_ref_model
self.ref_model = self.class_ref_model

# Always move reward model to device
self.reward_model = self.reward_model.to(self.accelerator.device) # type: ignore
@@ -1019,15 +1021,15 @@
if self.ref_model is None and original_ref_model is not None:
print('Reference model was cleared during training - restoring')
self.ref_model = original_ref_model
CPPOTrainer.class_ref_model = original_ref_model
self.class_ref_model = original_ref_model

# Ensure the class variable is updated
CPPOTrainer.class_ref_model = self.ref_model
self.class_ref_model = self.ref_model
if self.is_deepspeed_enabled:
CPPOTrainer.ds_wrapped_models = self.deepspeed
self.ds_wrapped_models = self.deepspeed
else:
CPPOTrainer.ds_wrapped_models = self.model
CPPOTrainer.policy_value_models = self.model
self.ds_wrapped_models = self.model
self.policy_value_models = self.model

def evaluate(self) -> Dict[str, float]:
"""Custom evaluation method for PPO. Generates completions from the evaluation dataloader,
@@ -1240,32 +1242,29 @@
self.is_final_eval = is_final
return self

def save_model(
self, output_dir: Optional[str] = None, _internal_call: bool = False
) -> None:
"""Save the model, dealing with the case where it's a PEFT model without a policy attribute."""
# Store the original model
original_model = self.model

# For PEFT models (which lack .policy attribute), use the model directly
if hasattr(self.model, 'base_model'):
# PEFT model case - don't try to access .policy
pass # Keep the model as is
elif hasattr(self.model, 'policy'):
# Standard PPO case - use the policy as in the original implementation
self.model = self.model.policy
elif hasattr(self.model, 'policy_model'):
# Standard PPO case - use the policy_model as in the original implementation
self.model = self.model.policy_model

# Call the parent class's save_model
if output_dir is None:
output_dir = self.args.output_dir

Trainer.save_model(self, output_dir, _internal_call)

# Restore the original model
self.model = original_model
def save_model(self, output_dir: str, _internal_call: bool = True) -> None:
"""Manually save the model (and training state) to a specified directory.

This follows a similar procedure to `Trainer._save_checkpoint`.
"""
# Save the model files to output_dir
Trainer.save_model(self, output_dir, _internal_call=_internal_call)

# If not saving only the model, save optimizer, scheduler, and RNG state
if not self.args.save_only_model:
self._save_optimizer_and_scheduler(output_dir)
self._save_scaler(output_dir)
self._save_rng_state(output_dir)

# Save the trainer state
trainer_state_path = os.path.join(output_dir, "trainer_state.json")
self.state.save_to_json(trainer_state_path)

# Optionally push to hub if that option is enabled
if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)


def get_cppo_plasticity_weights(
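The reworked `save_model` above mirrors the steps of `Trainer._save_checkpoint`: model files first, then (unless `save_only_model` is set) optimizer, scheduler, scaler, and RNG state, then `trainer_state.json`. As a rough, self-contained illustration of the model-plus-state part with a plain `transformers.Trainer` (the tiny test checkpoint and temporary directory are assumptions, and the optimizer/scheduler/RNG files only appear during an actual training run):

```python
# Rough illustration of what a task checkpoint contains at minimum:
# model weights/config via Trainer.save_model, plus trainer_state.json.
import os
import tempfile

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

with tempfile.TemporaryDirectory() as tmp:
    model = AutoModelForSequenceClassification.from_pretrained(
        'hf-internal-testing/tiny-random-bert', num_labels=1  # assumed tiny test model
    )
    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir=tmp, report_to='none'),
    )

    ckpt_dir = os.path.join(tmp, 'last')
    trainer.save_model(ckpt_dir)  # weights + config
    trainer.state.save_to_json(os.path.join(ckpt_dir, 'trainer_state.json'))
    print(sorted(os.listdir(ckpt_dir)))
```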