diff --git a/benchmarks/continual_eval_checkpoints.py b/benchmarks/continual_eval_checkpoints.py index 016887fa..6f748b1d 100644 --- a/benchmarks/continual_eval_checkpoints.py +++ b/benchmarks/continual_eval_checkpoints.py @@ -1,9 +1,9 @@ -"""Evaluating checkpoints obtained from training using the dpo_continual script.""" - import glob import os +import re import torch +import wandb as wb from dataloading import init_continual_dataset from datasets import Dataset from dpo.continual_dpo_trainer import ( @@ -17,9 +17,7 @@ AutoTokenizer, ) from trl import ( - DPOConfig, ModelConfig, - ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, @@ -27,12 +25,10 @@ ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -import wandb as wb - def main( - script_args: ScriptArguments, - training_args: DPOConfig, + script_args: ContinualDPOArguments, + training_args: ContinualDPOConfig, model_args: ModelConfig, ) -> None: # Determine torch dtype and quantization configs @@ -41,6 +37,9 @@ def main( if model_args.torch_dtype in ['auto', None] else getattr(torch, model_args.torch_dtype) ) + if script_args.wandb_run_name is not None: + training_args.run_name = script_args.wandb_run_name + quantization_config = get_quantization_config(model_args) # Model & Tokenizer Setup @@ -87,14 +86,26 @@ def main( # Validate reward model paths if provided for i, _ in enumerate(continual_dataset): - reward_path = os.path.join(training_args.reward_model_path, str(i)) + reward_path = training_args.reward_model_path + '_' + str(i) if not os.path.exists(reward_path): raise FileNotFoundError( f'Reward model not found for dataset {i} at {reward_path}' ) checkpoint_paths = glob.glob(f'{script_args.checkpoint_dir}/*/*') - checkpoint_paths = sorted([ch for ch in checkpoint_paths if 'checkpoint' in ch]) + + def extract_indices(path): + match = re.search(r'dataset-(\d+)/checkpoint-(\d+)', path) + if match: + dataset_idx = int(match.group(1)) + checkpoint_idx = int(match.group(2)) + return (dataset_idx, checkpoint_idx) + else: + return (float('inf'), float('inf')) # in case of unexpected format + + checkpoint_paths = [ch for ch in checkpoint_paths if 'checkpoint' in ch] + checkpoint_paths.sort(key=extract_indices) + print('checkpoint_paths', checkpoint_paths) # Checkpoint loop for checkpoint_path in checkpoint_paths: @@ -103,14 +114,20 @@ def main( print( f'Evaluating checkpoint: {checkpoint_step} trained on dataset: {dataset_name} on all tasks' ) - adapter_name = dataset_name + checkpoint_step - model.load_adapter(checkpoint_path, adapter_name=adapter_name) + # adapter_name = dataset_name + checkpoint_step + # model.load_adapter(checkpoint_path, adapter_name=adapter_name) + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) metrics = {} # Task Loop for i, dataset in enumerate(continual_dataset): + print('task', i) reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path + f'/{str(i)}', num_labels=1 + training_args.reward_model_path + f'_{str(i)}', num_labels=1 ) training_args.output_dir = f'{output_dir}/dataset-{i}' @@ -129,8 +146,18 @@ def main( ev_metrics = trainer.evaluate() ev_metrics = {f'dataset-{i}/' + k: v for k, v in ev_metrics.items()} metrics.update(ev_metrics) - - wb.log(metrics) # type: ignore[attr-defined] + if training_args.local_rank in (None, -1, 0): + wb.log({f'task/{dataset_name}/{k}': v for k, v in ev_metrics.items()}) + + # If using DeepSpeed through Accelerate, tear down the engine after training. + if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None: + # Remove reference to the DeepSpeed engine to allow proper cleanup. + del trainer.deepspeed + # Free cached GPU memory. + torch.cuda.empty_cache() + + if training_args.local_rank in (None, -1, 0): + wb.log(metrics) # type: ignore[attr-defined] print('Evaluation completed for all tasks and checkpoints!') diff --git a/benchmarks/dataloading.py b/benchmarks/dataloading.py index b93d07b2..c6d8c704 100644 --- a/benchmarks/dataloading.py +++ b/benchmarks/dataloading.py @@ -89,8 +89,11 @@ def init_continual_dataset( data = ContinualAlignmentDataset.from_json(dataset) except OSError: # need to try downloading from hub try: + # print(f'Downloading {json_name} from Hugging Face Hub...') local_path = hf_hub_download( - repo_id=dataset, filename='dataset.json', repo_type='dataset' + repo_id=f'LifelongAlignment/{dataset}', + filename='data.json', + repo_type='dataset', ) data = ContinualAlignmentDataset.from_json(local_path) except Exception as e: diff --git a/benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml b/benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 00000000..f369ef96 --- /dev/null +++ b/benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml b/benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml index 7f17a48f..6b68067b 100644 --- a/benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml +++ b/benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml @@ -11,7 +11,7 @@ machine_rank: 0 main_training_function: main mixed_precision: bf16 num_machines: 1 -num_processes: 1 # TODO change to whatever number of gpus is used +num_processes: 8 # TODO change to whatever number of gpus is used rdzv_backend: static same_network: true tpu_env: [] diff --git a/benchmarks/dpo/continual_dpo_trainer.py b/benchmarks/dpo/continual_dpo_trainer.py index 8b68e7e6..024cbc08 100644 --- a/benchmarks/dpo/continual_dpo_trainer.py +++ b/benchmarks/dpo/continual_dpo_trainer.py @@ -9,9 +9,12 @@ import pandas as pd import torch import torch.nn as nn +import wandb as wb from accelerate import Accelerator, PartialState from accelerate.utils import gather_object from datasets import Dataset +from rich.console import Console +from rich.table import Table from torch.utils.data import DataLoader from transformers import ( BaseImageProcessor, @@ -36,8 +39,6 @@ ) from typing_extensions import override -import wandb as wb - @dataclass class ContinualDPOArguments(ScriptArguments): @@ -285,7 +286,10 @@ def evaluate_policy(self) -> dict: with torch.no_grad(): if self.eval_policy_dataloader is not None: - for batch in self.eval_policy_dataloader: + for idx, batch in enumerate(self.eval_policy_dataloader): + print( + f'Processing batch {idx} out of {len(self.eval_policy_dataloader)}' + ) query = batch['input_ids'].to(self.accelerator.device) context_length = query.shape[1] with unwrap_model_for_generation( @@ -324,26 +328,38 @@ def log( train_eval = 'train' if 'loss' in logs else 'eval' print(f'Logging {train_eval} metrics...') if train_eval == 'eval': - print('Computing policy metrics...') - eval_policy_metrics = self.evaluate_policy() - logs.update(eval_policy_metrics) + if self.reward_model is not None: + print('Computing policy metrics...') + eval_policy_metrics = self.evaluate_policy() + logs.update(eval_policy_metrics) # TODO: Only generation sample completions every x steps do_generate_completions = True if do_generate_completions: + print('Generating completions...') self._generate_completions() torch.cuda.empty_cache() + return super().log(logs, start_time) def _generate_completions(self) -> None: # Config from: https://github.com/huggingface/trl/blob/56e57662053e2d0cc6302dad404820b0c0ec6a91/trl/trainer/ppo_trainer.py#L688 + # generation_config = GenerationConfig( + # max_new_tokens=53, + # temperature=(0.01 + 1e-7), + # top_k=0.0, + # top_p=1.0, + # do_sample=True, + # ) generation_config = GenerationConfig( - max_new_tokens=53, - temperature=(0.01 + 1e-7), + max_new_tokens=self.args.response_length, + temperature=(self.args.temperature + 1e-7), top_k=0.0, top_p=1.0, do_sample=True, ) + + self.model.eval() table = defaultdict(list) with torch.no_grad(): with unwrap_model_for_generation( @@ -351,44 +367,62 @@ def _generate_completions(self) -> None: self.accelerator, gather_deepspeed3_params=None, ) as unwrapped_model: - for batch in self.eval_dataloader: - query = batch['input_ids'] - context_length = query.shape[1] - query_response, _ = batch_generation( - unwrapped_model, - query, - query.shape[0], - self.processing_class.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - postprocessed_query_response = torch.cat( - (query, postprocessed_response), 1 - ) - _, score, _ = get_reward( - self.reward_model, - postprocessed_query_response, - self.processing_class.pad_token_id, - context_length, - ) + if self.eval_policy_dataloader is not None: + for batch in self.eval_policy_dataloader: + query = batch['input_ids'] + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model, + query, + query.shape[0], + self.processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + postprocessed_query_response = torch.cat( + (query, postprocessed_response), 1 + ) + _, score, _ = get_reward( + self.reward_model, + postprocessed_query_response, + self.processing_class.pad_token_id, + context_length, + ) - queries = gather_object( - self.processing_class.batch_decode( - query, skip_special_tokens=True + queries = gather_object( + self.processing_class.batch_decode( + query, skip_special_tokens=True + ) ) - ) - responses = gather_object( - self.processing_class.batch_decode(postprocessed_response) - ) - scores = ( - self.accelerator.gather_for_metrics(score).float().cpu().numpy() - ) - table['query'].extend(queries) - table['model response'].extend(responses) - table['score'].extend(scores) - break + responses = gather_object( + self.processing_class.batch_decode(postprocessed_response) + ) + scores = ( + self.accelerator.gather_for_metrics(score) + .float() + .cpu() + .numpy() + ) + table['query'].extend(queries) + table['model response'].extend(responses) + table['score'].extend(scores) + break + self.model.train() df = pd.DataFrame(table) - if self.accelerator.is_main_process and wb.run is not None: - wb.log({'completions': wb.Table(dataframe=df)}) + + if self.accelerator.is_main_process or self.accelerator is None: + print_rich_table(df.iloc[0 : 0 + 5]) + if wb.run is not None: + wb.log({'completions': wb.Table(dataframe=df)}) + + +def print_rich_table(df: pd.DataFrame) -> Table: + console = Console() + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.print(table) diff --git a/benchmarks/dpo/dpo_continual.py b/benchmarks/dpo/dpo_continual.py index 080d8d51..7ff9bd4a 100644 --- a/benchmarks/dpo/dpo_continual.py +++ b/benchmarks/dpo/dpo_continual.py @@ -3,11 +3,7 @@ import os import torch -from continual_dpo_trainer import ( - ContinualDPOArguments, - ContinualDPOConfig, - ContinualDPOTrainer, -) +import wandb as wb from datasets import Dataset from transformers import ( AutoModelForCausalLM, @@ -23,7 +19,6 @@ ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -import wandb as wb from benchmarks.dataloading import init_continual_dataset from benchmarks.dpo.continual_dpo_trainer import ( ContinualDPOArguments, @@ -104,7 +99,7 @@ def main( # first check the hub if the model is present try: AutoModelForSequenceClassification.from_pretrained( - reward_path, num_labels=1 + reward_path, num_labels=1, use_cache=True ) except: # if not found in the hub, check the local path @@ -137,6 +132,9 @@ def main( peft_config=peft_config, ) + # if i == 0: + # trainer.save_model(os.path.join(training_args.output_dir, 'checkpoint-0')) + # TODO will throw Invalidate trace cache @ step 10: expected module 11, but got module 19 # https://github.com/deepspeedai/DeepSpeed/issues/6870 # Fix with deepspeed fix release @@ -152,8 +150,9 @@ def main( print(f'eval/dataset/{i}') trainer.log_metrics(f'eval/dataset/{i}', metrics) trainer.save_metrics(f'eval', metrics) - wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] - wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined] + if training_args.local_rank in (None, -1, 0): + wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] + wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined] # Save and push to hub trainer.save_model(os.path.join(training_args.output_dir, 'last')) diff --git a/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py b/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py index 815dd537..5ee13556 100644 --- a/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py +++ b/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, Optional, Union +import deepspeed import torch import torch.nn as nn from transformers import PreTrainedModel @@ -115,29 +116,50 @@ def compute_ewc_loss(self) -> torch.Tensor: # No previous tasks, so no regularization needed return torch.tensor(0.0, device=self.accelerator.device) - ewc_loss = torch.tensor(0.0, device=self.accelerator.device) - # Calculate the EWC penalty for each parameter model = self.accelerator.unwrap_model(self.model) + ewc_loss = torch.tensor(0.0, device=self.accelerator.device) for name, param in model.named_parameters(): - if ( - name in ContinualDPOEWCTrainer.class_fisher_information - and param.requires_grad - ): - # Get the Fisher information and old parameter values - fisher = ContinualDPOEWCTrainer.class_fisher_information[name].to( - param.device - ) - old_param = ContinualDPOEWCTrainer.class_old_params[name].to( - param.device - ) + if not param.requires_grad or name not in self.class_fisher_information: + continue + # self.accelerator.print(name, param.shape) + with deepspeed.zero.GatheredParameters([param], modifier_rank=0): + if self.accelerator.is_main_process: + # Get the Fisher information and old parameter values + fisher = ContinualDPOEWCTrainer.class_fisher_information[name].to( + self.accelerator.device + ) + old_param = ContinualDPOEWCTrainer.class_old_params[name].to( + self.accelerator.device + ) + + # Calculate squared distance weighted by Fisher information + delta = param - old_param + ewc_loss = ewc_loss + (fisher * delta.pow(2)).sum() - # Calculate squared distance weighted by Fisher information - delta = param - old_param - ewc_loss += (fisher * delta.pow(2)).sum() + # Apply the EWC lambda coefficient and return + ewc_loss = 0.5 * self.ewc_lambda * ewc_loss + else: + # Non-main processes should not compute EWC loss + ewc_loss = torch.tensor(0.0, device=self.accelerator.device) - # Apply the EWC lambda coefficient and return - return 0.5 * self.ewc_lambda * ewc_loss + ewc_loss = self.accelerator.reduce(ewc_loss, 'mean') + return ewc_loss + + def store_current_parameters(self) -> Dict[str, torch.Tensor]: + """Store the current model parameters. + + Returns: + Dictionary mapping parameter names to their current values + """ + model = self.accelerator.unwrap_model(self.model) + old_params = {} + for name, param in model.named_parameters(): + with deepspeed.zero.GatheredParameters([param], modifier_rank=0): + if self.accelerator.is_main_process: + if param.requires_grad: + old_params[name] = param.data.clone().detach() + return old_params def compute_fisher_information( self, num_samples: int = 120 @@ -152,11 +174,6 @@ def compute_fisher_information( """ # Get unwrapped model for computing Fisher model = self.accelerator.unwrap_model(self.model) - self.accelerator.device - - # Make sure parameters require gradients - for param in model.parameters(): - param.requires_grad_(True) # Initialize fisher information dictionary fisher_info = {} @@ -197,7 +214,9 @@ def compute_fisher_information( model.zero_grad() try: - loss, _ = self.compute_loss(model, batch, return_outputs=True) + loss, _ = super(ContinualDPOEWCTrainer, self).compute_loss( + model, batch, return_outputs=True + ) # Check if loss requires gradient if not loss.requires_grad: @@ -229,19 +248,6 @@ def compute_fisher_information( print(f'Computed Fisher information for {sample_count} examples') return fisher_info - def store_current_parameters(self) -> Dict[str, torch.Tensor]: - """Store the current model parameters. - - Returns: - Dictionary mapping parameter names to their current values - """ - model = self.accelerator.unwrap_model(self.model) - old_params = {} - for name, param in model.named_parameters(): - if param.requires_grad: - old_params[name] = param.data.clone().detach() - return old_params - def train(self) -> Any: """Override train method to incorporate EWC regularization.""" # Regular training diff --git a/benchmarks/dpo_ewc/dpo_EWC_continual.py b/benchmarks/dpo_ewc/dpo_EWC_continual.py index 35a00b1e..021a4e71 100644 --- a/benchmarks/dpo_ewc/dpo_EWC_continual.py +++ b/benchmarks/dpo_ewc/dpo_EWC_continual.py @@ -3,6 +3,7 @@ import os import torch +import wandb as wb from continual_dpo_EWC_trainer import ( ContinualDPOEWCArguments, ContinualDPOEWCConfig, @@ -23,7 +24,6 @@ ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -import wandb as wb from benchmarks.dataloading import init_continual_dataset @@ -132,6 +132,9 @@ def main( peft_config=peft_config, ) + # if i == 0: + # trainer.save_model(os.path.join(training_args.output_dir, 'checkpoint-0')) + # TODO will throw Invalidate trace cache @ step 10: expected module 11, but got module 19 # https://github.com/deepspeedai/DeepSpeed/issues/6870 # Fix with deepspeed fix release @@ -147,8 +150,9 @@ def main( print(f'eval/dataset/{i}') trainer.log_metrics(f'eval/dataset/{i}', metrics) trainer.save_metrics(f'eval', metrics) - wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] - wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined] + if training_args.local_rank in (None, -1, 0): + wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] + wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined] # Save and push to hub trainer.save_model(os.path.join(training_args.output_dir, 'last')) diff --git a/benchmarks/hf_upload_models.py b/benchmarks/hf_upload_models.py new file mode 100644 index 00000000..c50c805c --- /dev/null +++ b/benchmarks/hf_upload_models.py @@ -0,0 +1,29 @@ +from huggingface_hub import HfApi, upload_folder + +datasets = 'aifgen-long-piecewise aifgen-lipschitz aifgen-piecewise-preference-shift aifgen-domain-preference-shift aifgen-short-piecewise CPPO-REWARD' +dataset_indices = '0 1 2 3 4 5 6 7 8 9' +model = 'Qwen2-0.5B' +# datasets="aifgen-long-piecewise" +# dataset_indices="0" + +for dataset_name in datasets.split(): + for dataset_index in dataset_indices.split(): + # Upload the model to the Hugging Face Hub + try: + repo_id = f'LifelongAlignment/{model}-Instruct_{dataset_name}_REWARD_{dataset_index}' + api = HfApi() + api.create_repo(repo_id, repo_type='model', exist_ok=True, private=False) + + path = f'/lustre/orion/bif151/scratch/ivan.anokhin/AIF-Gen/{dataset_name}/{model}-Reward-8gpus/{model}-Instruct_{dataset_name}_REWARD_{dataset_index}' + print('path', path) + + upload_folder( + repo_id=repo_id, + # path_in_repo=f"{dataset_name}-{dataset_index}/reward-model", + folder_path=path, + commit_message='Upload AIFGen reward model', + repo_type='model', + ) + except: + print(f'Failed to upload {dataset_name}-{dataset_index} reward model') + continue diff --git a/benchmarks/parallel_eval_checkpoints.py b/benchmarks/parallel_eval_checkpoints.py new file mode 100644 index 00000000..c4921819 --- /dev/null +++ b/benchmarks/parallel_eval_checkpoints.py @@ -0,0 +1,169 @@ +import os + +import torch +import wandb as wb +from dataloading import init_continual_dataset +from datasets import Dataset +from dpo.continual_dpo_trainer import ( + ContinualDPOArguments, + ContinualDPOConfig, + ContinualDPOTrainer, +) +from safetensors import safe_open +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, +) +from trl import ( + ModelConfig, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +def main( + script_args: ContinualDPOArguments, + training_args: ContinualDPOConfig, + model_args: ModelConfig, +) -> None: + # Determine torch dtype and quantization configs + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ['auto', None] + else getattr(torch, model_args.torch_dtype) + ) + if script_args.wandb_run_name is not None: + training_args.run_name = script_args.wandb_run_name + + quantization_config = get_quantization_config(model_args) + + # Model & Tokenizer Setup + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + # Checkpoint loop + checkpoint_path = script_args.checkpoint_dir + if 'PPO' in checkpoint_path: + dataset_name = 'dataset-' + checkpoint_path.split('/')[-2].split('_')[-1] + else: + dataset_name = checkpoint_path.split('/')[-2].replace('.', '') + + checkpoint_step = checkpoint_path.split('/')[-1].replace('.', '') + print( + f'Evaluating checkpoint: {checkpoint_step} trained on dataset: {dataset_name} on all tasks' + ) + checkpoint_name = dataset_name + '_' + checkpoint_step + print('checkpoint_name', checkpoint_name) + + if 'PPO' in checkpoint_path: + # remove the prefix 'policy.' from the keys to load the model; skip the critic and value model + prefix = 'policy.' + with safe_open( + checkpoint_path + '/model.safetensors', framework='pt', device='cpu' + ) as f: + clean_sd = { + k[len(prefix) :] if k.startswith(prefix) else k: f.get_tensor(k) + for k in f.keys() + if not ( + k.startswith('critic_backbone.') or k.startswith('value_model.') + ) + } + + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=model_args.trust_remote_code, + state_dict=clean_sd, + **model_kwargs, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + peft_config = get_peft_config(model_args) + + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + + # Load tokenizer and set chat template if needed + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + # Initialize continual dataset + continual_dataset: list[dict[str, Dataset]] = init_continual_dataset( + script_args.dataset_name, + mock=training_args.mock, + tokenizer=tokenizer, + tools=getattr(training_args, 'tools', None), + ) + output_dir = training_args.output_dir + + # Validate reward model paths if provided + for i, _ in enumerate(continual_dataset): + reward_path = training_args.reward_model_path + '_' + str(i) + if not os.path.exists(reward_path): + raise FileNotFoundError( + f'Reward model not found for dataset {i} at {reward_path}' + ) + + # Task Loop + for i, dataset in enumerate(continual_dataset): + print('task', i) + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path + f'_{str(i)}', num_labels=1 + ) + + training_args.output_dir = f'{output_dir}/dataset-{i}' + # using ContinualDPOTrainer for all pipelines (PPO, DPO, COPR, ..) only for evaluation + trainer = ContinualDPOTrainer( + args=training_args, + processing_class=tokenizer, + model=model, + ref_model=ref_model, + reward_model=reward_model, + train_dataset=dataset[script_args.dataset_test_split], + eval_dataset=dataset[script_args.dataset_test_split], + peft_config=peft_config, + ) + + print('evaluating...') + ev_metrics = trainer.evaluate() + # ev_metrics = {f'dataset-{i}/' + k: v for k, v in ev_metrics.items()} + if training_args.local_rank in (None, -1, 0): + print('ev_metrics', ev_metrics) + wb.log(ev_metrics) + wb.log({f'{checkpoint_name}/{k}': v for k, v in ev_metrics.items()}) + + # If using DeepSpeed through Accelerate, tear down the engine after training. + if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None: + # Remove reference to the DeepSpeed engine to allow proper cleanup. + del trainer.deepspeed + # Free cached GPU memory. + torch.cuda.empty_cache() + + +if __name__ == '__main__': + dataclass_types = (ContinualDPOArguments, ContinualDPOConfig, ModelConfig) + parser = TrlParser(dataclass_types) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/benchmarks/ppo/README.md b/benchmarks/ppo/README.md index e49ee680..0ba63852 100644 --- a/benchmarks/ppo/README.md +++ b/benchmarks/ppo/README.md @@ -20,6 +20,7 @@ uv run benchmarks/ppo/ppo_continual.py \ --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \ --learning_rate 5.0e-6 \ --num_train_epochs 1 \ + --gradient_accumulation_steps 2 \ --gradient_accumulation_steps 8 \ --gradient_checkpointing \ --logging_steps 20 \ @@ -32,7 +33,7 @@ uv run benchmarks/ppo/ppo_continual.py \ --use_peft \ --lora_r 32 \ --lora_alpha 16 \ - --push_to_hub True + --push_to_hub False ``` ### Using accelerate launch (with DeepSpeed / multi-GPU) @@ -50,19 +51,19 @@ accelerate launch --config_file benchmarks/ppo/accelerate_configs/deepspeed_zero --learning_rate 5.0e-6 \ --num_train_epochs 1 \ --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ + --gradient_accumulation_steps 1 \ --gradient_checkpointing \ - --logging_steps 2 \ + --logging_steps 10 \ --eval_strategy steps \ - --eval_steps 5 \ - --save_steps 5 \ + --eval_steps 10 \ + --save_steps 10 \ --bf16 \ --output_dir "$SCRATCH/Qwen2-0.5B-PPO-test" \ --no_remove_unused_columns \ --use_peft \ --lora_r 32 \ --lora_alpha 16 \ - --push_to_hub True + --push_to_hub False ``` *Make sure you do not add the dataset index to the reward model name as the script itself iterates over the dataset indices.* @@ -70,7 +71,8 @@ accelerate launch --config_file benchmarks/ppo/accelerate_configs/deepspeed_zero ### Full Training (without PEFT push, for local evaluation) ```sh -uv run benchmarks/ppo/ppo_continual.py \ +accelerate launch --config_file benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml \ + benchmarks/ppo/ppo_continual.py \ --dataset_name benchmarks/continual_data_debug.json \ --mock False \ --sft_model_path Qwen/Qwen2-0.5B-Instruct \ @@ -78,14 +80,16 @@ uv run benchmarks/ppo/ppo_continual.py \ --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \ --learning_rate 5.0e-7 \ --num_train_epochs 1 \ - --per_device_train_batch_size 2 \ - --gradient_accumulation_steps 8 \ + --bf16 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ --gradient_checkpointing \ --logging_steps 20 \ --eval_strategy steps \ --eval_steps 20 \ --output_dir "$SCRATCH/Qwen2-0.5B-PPO" \ - --no_remove_unused_columns + --no_remove_unused_columns \ + --push_to_hub False ``` ### Run a Sweep with wandb diff --git a/benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml b/benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml index 8046cccc..239b14ac 100644 --- a/benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml +++ b/benchmarks/ppo/accelerate_configs/deepspeed_zero2.yaml @@ -12,7 +12,7 @@ machine_rank: 0 main_training_function: main mixed_precision: 'bf16' num_machines: 1 -num_processes: 1 +num_processes: 8 rdzv_backend: static same_network: true tpu_env: [] diff --git a/benchmarks/ppo/continual_ppo_trainer.py b/benchmarks/ppo/continual_ppo_trainer.py index 0c20471c..cb3d747f 100644 --- a/benchmarks/ppo/continual_ppo_trainer.py +++ b/benchmarks/ppo/continual_ppo_trainer.py @@ -311,8 +311,8 @@ def __init__( # Training scheduling args.num_total_batches = math.ceil(args.total_episodes / args.batch_size) time_tensor = torch.tensor(int(time.time()), device=self.accelerator.device) - time_int = broadcast(time_tensor, 0).item() - args.run_name = f'{args.exp_name}__{args.seed}__{time_int}' + broadcast(time_tensor, 0).item() + # args.run_name = f'{args.exp_name}__{args.seed}__{time_int}' self.local_seed = args.seed + self.accelerator.process_index * 100003 # Prime if args.num_sample_generations > 0: self.sample_generations_freq = max( @@ -729,110 +729,110 @@ def repeat_generator() -> DataLoader: for micro_batch_start in range( 0, args.local_mini_batch_size, args.per_device_train_batch_size ): - with accelerator.accumulate(model): - micro_batch_end = ( - micro_batch_start + args.per_device_train_batch_size - ) - micro_batch_inds = mini_batch_inds[ - micro_batch_start:micro_batch_end - ] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - - output, vpred_temp = forward( - model, mb_query_responses, processing_class.pad_token_id - ) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_logprobs = selective_log_softmax(logits, mb_responses) - new_logprobs = torch.masked_fill( - new_logprobs, - padding_mask[micro_batch_inds], - INVALID_LOGPROB, - ) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill( - vpred, padding_mask_p1[micro_batch_inds], 0 - ) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean( - vf_loss_max, ~padding_mask_p1[micro_batch_inds] - ) - vf_clipfrac = masked_mean( - (vf_losses2 > vf_losses1).float(), - ~padding_mask_p1[micro_batch_inds], - ) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp( - ratio, 1.0 - args.cliprange, 1.0 + args.cliprange + # with accelerator.accumulate(model): + micro_batch_end = ( + micro_batch_start + args.per_device_train_batch_size + ) + micro_batch_inds = mini_batch_inds[ + micro_batch_start:micro_batch_end + ] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward( + model, mb_query_responses, processing_class.pad_token_id + ) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_logprobs = selective_log_softmax(logits, mb_responses) + new_logprobs = torch.masked_fill( + new_logprobs, + padding_mask[micro_batch_inds], + INVALID_LOGPROB, + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill( + vpred, padding_mask_p1[micro_batch_inds], 0 + ) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean( + vf_loss_max, ~padding_mask_p1[micro_batch_inds] + ) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), + ~padding_mask_p1[micro_batch_inds], + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp( + ratio, 1.0 - args.cliprange, 1.0 + args.cliprange + ) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean( + pg_loss_max, ~padding_mask[micro_batch_inds] + ) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), + ~padding_mask[micro_batch_inds], ) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean( - pg_loss_max, ~padding_mask[micro_batch_inds] + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum( + prob_dist * logits, dim=-1 ) - loss = pg_loss + args.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), - ~padding_mask[micro_batch_inds], - ) - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum( - prob_dist * logits, dim=-1 - ) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = approxkl - pg_clipfrac_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = pg_clipfrac - pg_loss_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = pg_loss - vf_loss_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = vf_loss - vf_clipfrac_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = vf_clipfrac - entropy_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = entropy.mean() - ratio_stats[ - ppo_epoch_idx, - minibatch_idx, - gradient_accumulation_idx, - ] = ratio.mean() - gradient_accumulation_idx += 1 + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = approxkl + pg_clipfrac_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = pg_clipfrac + pg_loss_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = pg_loss + vf_loss_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = vf_loss + vf_clipfrac_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = vf_clipfrac + entropy_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = entropy.mean() + ratio_stats[ + ppo_epoch_idx, + minibatch_idx, + gradient_accumulation_idx, + ] = ratio.mean() + gradient_accumulation_idx += 1 minibatch_idx += 1 # del everything and empty cache # fmt: off diff --git a/benchmarks/ppo/ppo_continual.py b/benchmarks/ppo/ppo_continual.py index 5fe18513..8db6aff3 100644 --- a/benchmarks/ppo/ppo_continual.py +++ b/benchmarks/ppo/ppo_continual.py @@ -3,6 +3,7 @@ import os import torch +import wandb as wb from continual_ppo_trainer import ( ContinualPPOArguments, ContinualPPOConfig, @@ -23,7 +24,6 @@ ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -import wandb as wb from benchmarks.dataloading import init_continual_dataset @@ -100,6 +100,7 @@ def main( if '.' in clean_dataset_name: clean_dataset_name = clean_dataset_name.split('.')[0] + print(f'Training PPO on {len(continual_dataset)} tasks') # check if the reward models are present either in the path or in the hub if training_args.reward_model_path is not None: for i in range(len(continual_dataset)): @@ -143,6 +144,10 @@ def main( eval_dataset=dataset[script_args.dataset_test_split], peft_config=peft_config, ) + + # if i == 0: + # trainer.save_model(os.path.join(training_args.output_dir, 'checkpoint-0')) + # Set current task in trainer for task-based logging trainer.set_task(f'task_{i}') @@ -164,8 +169,9 @@ def main( trainer.save_metrics('eval', metrics) # Log metrics to WandB - wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] - wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined] + if training_args.local_rank in (None, -1, 0): + wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] + wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined] # Save model checkpoint and optionally push if not training_args.push_to_hub: diff --git a/benchmarks/ppo_ewc/ppo_EWC_continual.py b/benchmarks/ppo_ewc/ppo_EWC_continual.py index 211bc56a..c71e90e6 100644 --- a/benchmarks/ppo_ewc/ppo_EWC_continual.py +++ b/benchmarks/ppo_ewc/ppo_EWC_continual.py @@ -3,6 +3,7 @@ import os import torch +import wandb as wb from datasets import Dataset from transformers import ( AutoModelForCausalLM, @@ -18,7 +19,6 @@ ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -import wandb as wb from benchmarks.dataloading import init_continual_dataset from benchmarks.ppo_ewc.continual_ppo_EWC_trainer import ( ContinualPPOEWCArguments, @@ -176,8 +176,9 @@ def main( trainer.save_metrics('eval', metrics) # Log metrics to WandB - wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] - wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined] + if training_args.local_rank in (None, -1, 0): + wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] + wb.log({f'task/{custom_repo_name}/last': metrics}) # type: ignore[attr-defined] # Save model checkpoint and optionally push if not training_args.push_to_hub: diff --git a/benchmarks/reward_modeling.py b/benchmarks/reward_modeling.py index 93155178..f214cfb2 100644 --- a/benchmarks/reward_modeling.py +++ b/benchmarks/reward_modeling.py @@ -129,6 +129,7 @@ def train_model( trust_remote_code=model_args.trust_remote_code, **model_kwargs, ) + # Align padding tokens between tokenizer and model model.config.pad_token_id = tokenizer.pad_token_id @@ -226,6 +227,9 @@ def train_model( except Exception as e: print(f'Job {i + 1} failed with error: {e}') else: + print( + f'Running on {script_args.dataset_index+1} task out of {len(continual_dataset)} tasks' + ) dataset = continual_dataset[script_args.dataset_index] train_model( script_args, training_args, model_args, dataset, script_args.dataset_index diff --git a/pyproject.toml b/pyproject.toml index d547b374..7e6413dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,10 @@ dependencies = [ "numpy>=2.0.2", "openai>=1.61.1", "pydantic>=2.10.4", + "pytest-asyncio>=0.25.3", + "pytest-mock>=3.14.0", +# "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.0%2Brocm6.0-cp310-cp310-linux_x86_64.whl#sha256=266af54cf4704aae08719305c205f0d12f40874006d3b8058f38e2f8ed08f56d", + "types-pyyaml>=6.0.12.20241230", "torch>=2.6.0", ] diff --git a/uv.lock b/uv.lock index ecabef39..d0f6c71e 100644 --- a/uv.lock +++ b/uv.lock @@ -40,7 +40,10 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "pydantic" }, + { name = "pytest-asyncio" }, + { name = "pytest-mock" }, { name = "torch" }, + { name = "types-pyyaml" }, ] [package.dev-dependencies] @@ -81,7 +84,10 @@ requires-dist = [ { name = "numpy", specifier = ">=2.0.2" }, { name = "openai", specifier = ">=1.61.1" }, { name = "pydantic", specifier = ">=2.10.4" }, + { name = "pytest-asyncio", specifier = ">=0.25.3" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "torch", specifier = ">=2.6.0" }, + { name = "types-pyyaml", specifier = ">=6.0.12.20241230" }, ] [package.metadata.requires-dev] @@ -372,7 +378,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -1013,7 +1019,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -1429,7 +1435,6 @@ name = "nvidia-cublas-cu12" version = "12.4.5.8" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, ] @@ -1438,7 +1443,6 @@ name = "nvidia-cuda-cupti-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, ] @@ -1447,7 +1451,6 @@ name = "nvidia-cuda-nvrtc-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, ] @@ -1456,7 +1459,6 @@ name = "nvidia-cuda-runtime-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, ] @@ -1465,7 +1467,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -1476,10 +1478,9 @@ name = "nvidia-cufft-cu12" version = "11.2.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, ] @@ -1488,7 +1489,6 @@ name = "nvidia-curand-cu12" version = "10.3.5.147" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, ] @@ -1497,12 +1497,11 @@ name = "nvidia-cusolver-cu12" version = "11.6.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, ] @@ -1511,10 +1510,9 @@ name = "nvidia-cusparse-cu12" version = "12.3.1.170" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] @@ -1523,7 +1521,6 @@ name = "nvidia-cusparselt-cu12" version = "0.6.2" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/98/8e/675498726c605c9441cf46653bd29cb1b8666da1fb1469ffa25f67f20c58/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:067a7f6d03ea0d4841c85f0c6f1991c5dda98211f6302cb83a4ab234ee95bef8", size = 149422781 }, { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 }, ] @@ -1549,7 +1546,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, ] @@ -1558,7 +1554,6 @@ name = "nvidia-nvtx-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, ] @@ -2474,23 +2469,23 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", version = "3.2.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'linux'" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'linux' or python_full_version >= '3.11'" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -2517,7 +2512,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [ @@ -2571,6 +2566,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/29/25378447c48359843de0e4ce1995d367210601c3b437ddf1c779b6393d74/trl-0.15.2-py3-none-any.whl", hash = "sha256:bf2b88e3cf5da08cd533dc03273d977965bd5d86c5878f76285fba45d9cb9634", size = 318931 }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20250402" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/68/609eed7402f87c9874af39d35942744e39646d1ea9011765ec87b01b2a3c/types_pyyaml-6.0.12.20250402.tar.gz", hash = "sha256:d7c13c3e6d335b6af4b0122a01ff1d270aba84ab96d1a1a1063ecba3e13ec075", size = 17282 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/56/1fe61db05685fbb512c07ea9323f06ea727125951f1eb4dff110b3311da3/types_pyyaml-6.0.12.20250402-py3-none-any.whl", hash = "sha256:652348fa9e7a203d4b0d21066dfb00760d3cbd5a15ebb7cf8d33c88a49546681", size = 20329 }, +] + [[package]] name = "typing-extensions" version = "4.12.2"