From 94c8f7d50d9a1b90d7c91d025c826dbde94e52d9 Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Sun, 28 Sep 2025 14:42:52 +0800 Subject: [PATCH 1/7] add:open_r1 --- open_r1/readme.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 open_r1/readme.md diff --git a/open_r1/readme.md b/open_r1/readme.md new file mode 100644 index 000000000..e69de29bb From 18d6733710d5e95e6de97af8a5ca1842a99979dd Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Sun, 28 Sep 2025 17:58:26 +0800 Subject: [PATCH 2/7] add:openr1 --- open_r1/src/mind_openr1/__init__.py | 8 + open_r1/src/mind_openr1/configs.py | 342 ++++++++ open_r1/src/mind_openr1/sft.py | 300 +++++++ open_r1/src/mind_openr1/sft_trainer.py | 449 +++++++++++ open_r1/src/mind_openr1/utils/__init__.py | 8 + open_r1/src/mind_openr1/utils/callbacks.py | 92 +++ .../src/mind_openr1/utils/code_providers.py | 366 +++++++++ .../utils/competitive_programming/__init__.py | 19 + .../competitive_programming/cf_scoring.py | 146 ++++ .../competitive_programming/code_patcher.py | 123 +++ .../competitive_programming/ioi_scoring.py | 335 ++++++++ .../competitive_programming/ioi_utils.py | 41 + .../competitive_programming/morph_client.py | 742 ++++++++++++++++++ .../competitive_programming/piston_client.py | 224 ++++++ .../utils/competitive_programming/utils.py | 11 + open_r1/src/mind_openr1/utils/data.py | 65 ++ open_r1/src/mind_openr1/utils/evaluation.py | 118 +++ open_r1/src/mind_openr1/utils/hub.py | 132 ++++ open_r1/src/mind_openr1/utils/import_utils.py | 30 + open_r1/src/mind_openr1/utils/model_utils.py | 42 + open_r1/src/mind_openr1/utils/routed_morph.py | 120 +++ .../src/mind_openr1/utils/routed_sandbox.py | 109 +++ .../src/mind_openr1/utils/wandb_logging.py | 13 + 23 files changed, 3835 insertions(+) create mode 100644 open_r1/src/mind_openr1/__init__.py create mode 100644 open_r1/src/mind_openr1/configs.py create mode 100644 open_r1/src/mind_openr1/sft.py create mode 100644 open_r1/src/mind_openr1/sft_trainer.py create mode 100644 open_r1/src/mind_openr1/utils/__init__.py create mode 100644 open_r1/src/mind_openr1/utils/callbacks.py create mode 100644 open_r1/src/mind_openr1/utils/code_providers.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/__init__.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/cf_scoring.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/code_patcher.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/ioi_scoring.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/ioi_utils.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/morph_client.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/piston_client.py create mode 100644 open_r1/src/mind_openr1/utils/competitive_programming/utils.py create mode 100644 open_r1/src/mind_openr1/utils/data.py create mode 100644 open_r1/src/mind_openr1/utils/evaluation.py create mode 100644 open_r1/src/mind_openr1/utils/hub.py create mode 100644 open_r1/src/mind_openr1/utils/import_utils.py create mode 100644 open_r1/src/mind_openr1/utils/model_utils.py create mode 100644 open_r1/src/mind_openr1/utils/routed_morph.py create mode 100644 open_r1/src/mind_openr1/utils/routed_sandbox.py create mode 100644 open_r1/src/mind_openr1/utils/wandb_logging.py diff --git a/open_r1/src/mind_openr1/__init__.py b/open_r1/src/mind_openr1/__init__.py new file mode 100644 index 
000000000..31cb1f347 --- /dev/null +++ b/open_r1/src/mind_openr1/__init__.py @@ -0,0 +1,8 @@ +""" +Mind-OpenR1: MindSpore implementation of OpenR1 +""" + +from .sft_trainer import SFTTrainer, SFTConfig +from .configs import ScriptArguments + +__all__ = ["SFTTrainer", "SFTConfig", "ScriptArguments"] diff --git a/open_r1/src/mind_openr1/configs.py b/open_r1/src/mind_openr1/configs.py new file mode 100644 index 000000000..59682ed5d --- /dev/null +++ b/open_r1/src/mind_openr1/configs.py @@ -0,0 +1,342 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional + +import trl + + +@dataclass +class DatasetConfig: + """Configuration for a dataset in a mixture.""" + + id: str + config: Optional[str] = None + split: str = "train" + columns: Optional[list[str]] = None + weight: Optional[float] = None + + +@dataclass +class DatasetMixtureConfig: + """Configuration for a mixture of datasets.""" + + datasets: list[DatasetConfig] + seed: int = 0 + test_split_size: Optional[float] = None + + +@dataclass +class ScriptArguments(trl.ScriptArguments): + """ + Extended version of ScriptArguments with support for dataset mixtures. + + Args: + dataset_mixture (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Configuration for creating dataset mixtures with advanced options. + Format: + dataset_mixture: + datasets: + - id: dataset_id1 + config: config_name + columns: + - col1 + - col2 + weight: 0.5 + - id: dataset_id2 + config: config_name + columns: + - col1 + - col2 + weight: 0.5 + seed: 42 + test_split_size: 0.1 + """ + + # Override the dataset_name to make it optional + dataset_name: Optional[str] = field( + default=None, metadata={"help": "Dataset name. Can be omitted if using dataset_mixture."} + ) + dataset_mixture: Optional[dict[str, Any]] = field( + default=None, + metadata={"help": "Configuration for creating dataset mixtures with advanced options like shuffling."}, + ) + + # Limit number of samples used for training. If None, use full training set + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "Maximum number of training samples to use. " + "If None (default), do not truncate the training dataset." + ) + }, + ) + + def __post_init__(self): + if self.dataset_name is None and self.dataset_mixture is None: + raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided") + + if self.dataset_mixture is not None: + if not isinstance(self.dataset_mixture, dict) or "datasets" not in self.dataset_mixture: + raise ValueError( + "dataset_mixture must be a dictionary with a 'datasets' key. 
" + "Expected format: {'datasets': [...], 'seed': int}" + ) + + datasets_list = [] + datasets_data = self.dataset_mixture.get("datasets", []) + + if isinstance(datasets_data, list): + for dataset_config in datasets_data: + datasets_list.append( + DatasetConfig( + id=dataset_config.get("id"), + config=dataset_config.get("config"), + split=dataset_config.get("split", "train"), + columns=dataset_config.get("columns"), + weight=dataset_config.get("weight", 1.0), + ) + ) + else: + raise ValueError("'datasets' must be a list of dataset configurations") + + self.dataset_mixture = DatasetMixtureConfig( + datasets=datasets_list, + seed=self.dataset_mixture.get("seed", 0), + test_split_size=self.dataset_mixture.get("test_split_size", None), + ) + + # Check that column names are consistent across all dataset configs + columns_sets = [set(dataset.columns) for dataset in datasets_list if dataset.columns is not None] + if columns_sets: + first_columns = columns_sets[0] + if not all(columns == first_columns for columns in columns_sets): + raise ValueError( + "Column names must be consistent across all dataset configurations in a mixture. " + f"Found different column sets: {[list(cols) for cols in columns_sets]}" + ) + + +# TODO: add the shared options with a mixin to reduce code duplication +@dataclass +class GRPOConfig(trl.GRPOConfig): + """ + args for callbacks, benchmarks etc + """ + + benchmarks: list[str] = field( + default_factory=lambda: [], + metadata={"help": "The benchmarks to run after training."}, + ) + callbacks: list[str] = field( + default_factory=lambda: [], + metadata={"help": "The callbacks to run during training."}, + ) + chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) + hub_model_revision: Optional[str] = field( + default="main", metadata={"help": "The Hub model branch to push the model to."} + ) + num_completions_to_print: int = field(default=0, metadata={"help": "Number of completions to print."}) + overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) + push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) + system_prompt: Optional[str] = field( + default=None, + metadata={"help": "The optional system prompt to use."}, + ) + wandb_log_unique_prompts: bool = field( + default=True, + metadata={ + "help": ("Whether to log the unique prompts to wandb. 
This will create a new run for each unique prompt.") + }, + ) + wandb_entity: Optional[str] = field( + default=None, + metadata={"help": ("The entity to store runs under.")}, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": ("The project to store runs under.")}, + ) + wandb_run_group: Optional[str] = field( + default=None, + metadata={"help": ("The group to store runs under.")}, + ) + + +@dataclass +class SFTConfig(trl.SFTConfig): + """ + args for callbacks, benchmarks etc + """ + + benchmarks: list[str] = field( + default_factory=lambda: [], + metadata={"help": "The benchmarks to run after training."}, + ) + callbacks: list[str] = field( + default_factory=lambda: [], + metadata={"help": "The callbacks to run during training."}, + ) + chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) + system_prompt: Optional[str] = field( + default=None, + metadata={"help": "The optional system prompt to use for benchmarking."}, + ) + hub_model_revision: Optional[str] = field( + default="main", + metadata={"help": "The Hub model branch to push the model to."}, + ) + overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) + push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) + wandb_entity: Optional[str] = field( + default=None, + metadata={"help": ("The entity to store runs under.")}, + ) + wandb_project: Optional[str] = field( + default=None, + metadata={"help": ("The project to store runs under.")}, + ) + wandb_run_group: Optional[str] = field( + default=None, + metadata={"help": ("The group to store runs under.")}, + ) + + +@dataclass +class GRPOScriptArguments(ScriptArguments): + """ + Script arguments for the GRPO training script. + + Args: + reward_funcs (`list[str]`): + List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', 'tag_count', 'code', 'ioi_code', 'code_format', 'soft_overlong_punishment'. + cosine_min_value_wrong (`float`): + Minimum reward for cosine scaling for wrong answers. + cosine_max_value_wrong (`float`): + Maximum reward for cosine scaling for wrong answers. + cosine_min_value_correct (`float`): + Minimum reward for cosine scaling for correct answers. + cosine_max_value_correct (`float`): + Maximum reward for cosine scaling for correct answers. + cosine_max_len (`int`): + Maximum length for cosine scaling. + code_language (`str`): + Language for code format reward. + max_completion_len (`int`): + Maximum number of tokens in completion. + soft_punish_cache (`int`): + Minimum number of tokens in completion. + """ + + reward_funcs: list[str] = field( + default_factory=lambda: ["accuracy", "format", "tag_count"], + metadata={ + "help": "List of reward functions. 
Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length', tag_count', 'code', 'code_format'" + }, + ) + cosine_min_value_wrong: float = field( + default=0.0, + metadata={"help": "Minimum reward for wrong answers"}, + ) + cosine_max_value_wrong: float = field( + default=-0.5, + metadata={"help": "Maximum reward for wrong answers"}, + ) + cosine_min_value_correct: float = field( + default=0.5, + metadata={"help": "Minimum reward for correct answers"}, + ) + cosine_max_value_correct: float = field( + default=1.0, + metadata={"help": "Maximum reward for correct answers"}, + ) + cosine_max_len: int = field( + default=1000, + metadata={"help": "Maximum length for scaling"}, + ) + repetition_n_grams: int = field( + default=3, + metadata={"help": "Number of n-grams for repetition penalty reward"}, + ) + repetition_max_penalty: float = field( + default=-1.0, + metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"}, + ) + code_language: str = field( + default="python", + # '(?:python|cpp)' + metadata={ + "help": "Language for code format reward. Based on E2B supported languages https://e2b.dev/docs/code-interpreting/supported-languages", + "choices": ["python", "javascript", "r", "java", "bash", "cpp"], + }, + ) + code_eval_test_batch_size: int = field( + default=1, + metadata={ + "help": "for each generation, evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. Useful to avoid overloading the eval server + save time on wrong solutions" + }, + ) + code_eval_scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = field( + default="weighted_sum", + metadata={"help": "use fraction of passed test cases as reward. If false, use 0/1 scoring."}, + ) + parallel_code_exec_per_proc: int = field( + default=2, + metadata={ + "help": "Number of parallel E2B code executions per process. Default of 2 is suitable for the Free Hobby tier of E2B with 8 GPUs used for training." + }, + ) + + dataset_prompt_column: str = field( + default="prompt", + metadata={"help": "Column to use as prompts for training."}, + ) + + e2b_router_url: Optional[str] = field( + default=None, + metadata={"help": "URL for the E2B router. See scripts/e2b_router.py"}, + ) + + morph_router_url: Optional[str] = field( + default=None, + metadata={"help": "URL for the MorphCloud router. See scripts/morph_router.py"}, + ) + + code_provider: Optional[str] = field( + default="e2b", + metadata={ + "help": "Provider for code execution. Options: 'e2b', 'local', 'morph'.", + "choices": ["e2b", "local", "morph"], + }, + ) + + ioi_provider: Optional[str] = field( + default="piston", + metadata={ + "help": "Provider for IOI code execution. 
Options: 'piston', 'morph'.", + "choices": ["piston", "morph"], + }, + ) + + max_completion_len: int = field( + default=16384, + metadata={"help": "Maximum number of characters in completion."}, + ) + soft_punish_cache: int = field( + default=4096, + metadata={"help": "Minimum number of characters in completion."}, + ) diff --git a/open_r1/src/mind_openr1/sft.py b/open_r1/src/mind_openr1/sft.py new file mode 100644 index 000000000..c51f6f82f --- /dev/null +++ b/open_r1/src/mind_openr1/sft.py @@ -0,0 +1,300 @@ +import logging +import os +import sys +from dataclasses import dataclass +from typing import Optional + +import mindspore +from mindspore import context as ms_context +import mindnlp + +import datasets +from mindnlp.transformers import ( + set_seed, + AutoTokenizer, + AutoModelForCausalLM, + get_last_checkpoint +) + +from mind_openr1.configs import ScriptArguments +from mind_openr1.sft_trainer import SFTTrainer, SFTConfig +from mind_openr1.utils import get_dataset +from mind_openr1.utils.callbacks import get_callbacks + +ms_context.set_context(mode=ms_context.PYNATIVE_MODE) +logger = logging.getLogger(__name__) + + +@dataclass +class ModelConfig: + """Model configuration compatible with mindnlp""" + model_name_or_path: str + model_revision: str = "main" + trust_remote_code: bool = False + use_flash_attention_2: bool = False + lora_r: Optional[int] = None + lora_alpha: Optional[int] = None + lora_dropout: Optional[float] = None + lora_target_modules: Optional[list] = None + use_peft: bool = False + + +def get_tokenizer_mindnlp(model_args: ModelConfig): + """Get tokenizer using mindnlp""" + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + return tokenizer + + +def get_model_mindnlp(model_args: ModelConfig): + """Get model using mindnlp""" + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ms_dtype=mindspore.float16 if model_args.use_flash_attention_2 else mindspore.float32, + ) + + return model + + +def get_peft_config_dict(model_args: ModelConfig): + """Get PEFT configuration if enabled""" + if not model_args.use_peft: + return None + + peft_config = { + "r": model_args.lora_r or 16, + "lora_alpha": model_args.lora_alpha or 32, + "lora_dropout": model_args.lora_dropout or 0.1, + "target_modules": model_args.lora_target_modules or ["q_proj", "v_proj"], + "bias": "none", + "task_type": "CAUSAL_LM", + } + + return peft_config + + +def setup_chat_format(model, tokenizer): + """Setup chat format for model and tokenizer""" + if tokenizer.chat_template is None: + logger.info("No chat template provided, setting up ChatML format") + # Simple ChatML template + tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + + # Add special tokens if needed + special_tokens = { + "additional_special_tokens": ["<|im_start|>", "<|im_end|>"] + } + tokenizer.add_special_tokens(special_tokens) + + # Resize model embeddings + model.resize_token_embeddings(len(tokenizer)) + + return model, tokenizer + + +def main(script_args, training_args, model_args): + set_seed(training_args.seed) + + ############### + # Setup logging + 
############### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = logging.INFO if training_args.logging_steps > 0 else logging.WARNING + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + + logger.info(f"Model parameters {model_args}") + logger.info(f"Script parameters {script_args}") + logger.info(f"Training parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") + + ###################################### + # Load dataset, tokenizer, and model # + ###################################### + dataset = get_dataset(script_args) + + # Optionally truncate training split if max_train_samples is provided + if getattr(script_args, "max_train_samples", None): + train_split = script_args.dataset_train_split + max_n = int(script_args.max_train_samples) + if max_n > 0: + dataset[train_split] = dataset[train_split].select(range(min(max_n, len(dataset[train_split])))) + + tokenizer = get_tokenizer_mindnlp(model_args) + model = get_model_mindnlp(model_args) + + # Setup chat format if needed + model, tokenizer = setup_chat_format(model, tokenizer) + + ############################ + # Initialize the SFT Trainer + ############################ + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None), + processing_class=tokenizer, + peft_config=get_peft_config_dict(model_args), + callbacks=get_callbacks(training_args, model_args), + ) + + ############### + # Training loop + ############### + logger.info("*** Train ***") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + metrics = train_result + metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + ################################## + # Save model and create model card + ################################## + logger.info("*** Save model ***") + + # Save model + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + + # Save everything else on main process + kwargs = { + "dataset_name": script_args.dataset_name, + "tags": ["open-r1", "mindspore", "mindnlp"], + } + + # Create model card + trainer.create_model_card(**kwargs) + + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(dataset[script_args.dataset_test_split]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + ############# + # push to hub + ############# + if training_args.push_to_hub: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + + # Script 
arguments + parser.add_argument("--dataset_name", type=str, required=True) + parser.add_argument("--dataset_config", type=str, default=None) + parser.add_argument("--dataset_train_split", type=str, default="train") + parser.add_argument("--dataset_test_split", type=str, default="test") + parser.add_argument("--max_train_samples", type=int, default=None) + + # Model arguments + parser.add_argument("--model_name_or_path", type=str, required=True) + parser.add_argument("--model_revision", type=str, default="main") + parser.add_argument("--trust_remote_code", action="store_true") + parser.add_argument("--use_flash_attention_2", action="store_true") + parser.add_argument("--use_peft", action="store_true") + parser.add_argument("--lora_r", type=int, default=16) + parser.add_argument("--lora_alpha", type=int, default=32) + parser.add_argument("--lora_dropout", type=float, default=0.1) + parser.add_argument("--lora_target_modules", type=str, nargs="+", default=None) + + # Training arguments + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument("--num_train_epochs", type=int, default=3) + parser.add_argument("--per_device_train_batch_size", type=int, default=8) + parser.add_argument("--per_device_eval_batch_size", type=int, default=8) + parser.add_argument("--learning_rate", type=float, default=5e-5) + parser.add_argument("--weight_decay", type=float, default=0.0) + parser.add_argument("--max_seq_length", type=int, default=512) + parser.add_argument("--logging_steps", type=int, default=10) + parser.add_argument("--save_steps", type=int, default=500) + parser.add_argument("--eval_steps", type=int, default=500) + parser.add_argument("--eval_strategy", type=str, default="steps") + parser.add_argument("--max_steps", type=int, default=-1) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--do_eval", action="store_true") + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--resume_from_checkpoint", type=str, default=None) + parser.add_argument("--dataset_text_field", type=str, default="text") + + args = parser.parse_args() + + # Create config objects + script_args = ScriptArguments( + dataset_name=args.dataset_name, + dataset_config=args.dataset_config, + dataset_train_split=args.dataset_train_split, + dataset_test_split=args.dataset_test_split, + max_train_samples=args.max_train_samples, + ) + + model_args = ModelConfig( + model_name_or_path=args.model_name_or_path, + model_revision=args.model_revision, + trust_remote_code=args.trust_remote_code, + use_flash_attention_2=args.use_flash_attention_2, + use_peft=args.use_peft, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + lora_target_modules=args.lora_target_modules, + ) + + training_args = SFTConfig( + output_dir=args.output_dir, + num_train_epochs=args.num_train_epochs, + per_device_train_batch_size=args.per_device_train_batch_size, + per_device_eval_batch_size=args.per_device_eval_batch_size, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + max_seq_length=args.max_seq_length, + logging_steps=args.logging_steps, + save_steps=args.save_steps, + eval_steps=args.eval_steps, + eval_strategy=args.eval_strategy, + max_steps=args.max_steps, + seed=args.seed, + do_eval=args.do_eval, + push_to_hub=args.push_to_hub, + resume_from_checkpoint=args.resume_from_checkpoint, + dataset_text_field=args.dataset_text_field, + ) + + main(script_args, training_args, model_args) \ No newline at end of file diff --git 
a/open_r1/src/mind_openr1/sft_trainer.py b/open_r1/src/mind_openr1/sft_trainer.py new file mode 100644 index 000000000..b0b349787 --- /dev/null +++ b/open_r1/src/mind_openr1/sft_trainer.py @@ -0,0 +1,449 @@ +""" +Supervised Fine-tuning Trainer for MindSpore/MindNLP +""" +import logging +import os +import sys +from typing import Dict, List, Optional, Union, Any, Callable +from dataclasses import dataclass, field + +import mindspore +from mindspore import nn, ops, Tensor +from mindspore.dataset import GeneratorDataset +import mindspore.context as ms_context +import mindspore.communication as comm + +import datasets +from mindnlp.transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizer, + PreTrainedModel, + TrainingArguments as BaseTrainingArguments +) + +logger = logging.getLogger(__name__) + + +@dataclass +class SFTConfig(BaseTrainingArguments): + """ + Configuration class for SFT training specific parameters. + Inherits from mindnlp TrainingArguments. + """ + max_seq_length: int = field( + default=512, + metadata={"help": "Maximum sequence length for input"} + ) + dataset_text_field: str = field( + default="text", + metadata={"help": "Field name containing text in the dataset"} + ) + packing: bool = field( + default=False, + metadata={"help": "Whether to pack multiple examples in a single sequence"} + ) + dataset_train_split: str = field( + default="train", + metadata={"help": "Name of the training data split"} + ) + dataset_test_split: str = field( + default="test", + metadata={"help": "Name of the test data split"} + ) + + def __post_init__(self): + # Ensure output directory exists + if self.output_dir: + os.makedirs(self.output_dir, exist_ok=True) + + +class SFTTrainer: + """ + Supervised Fine-tuning Trainer for MindSpore/MindNLP + + This trainer handles the training loop for supervised fine-tuning of language models. 
+ """ + + def __init__( + self, + model: Optional[PreTrainedModel] = None, + args: Optional[SFTConfig] = None, + train_dataset: Optional[Union[datasets.Dataset, GeneratorDataset]] = None, + eval_dataset: Optional[Union[datasets.Dataset, GeneratorDataset]] = None, + processing_class: Optional[PreTrainedTokenizer] = None, + peft_config: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None, + ): + self.model = model + self.args = args or SFTConfig() + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = processing_class + self.peft_config = peft_config + self.callbacks = callbacks or [] + + # Training state + self.global_step = 0 + self.epoch = 0 + self.best_metric = None + self.best_model_checkpoint = None + + # Setup + self._setup_model() + self._setup_optimizer() + self._setup_datasets() + + def _setup_model(self): + """Setup model for training""" + if self.model is None: + raise ValueError("Model must be provided") + + # Set model to training mode + self.model.set_train(True) + + # Apply PEFT config if provided + if self.peft_config: + logger.info("Applying PEFT configuration") + # TODO: Implement PEFT integration + + def _setup_optimizer(self): + """Setup optimizer and learning rate scheduler""" + # Get trainable parameters + trainable_params = self.model.trainable_params() + + # Create optimizer + if self.args.learning_rate is None: + self.args.learning_rate = 5e-5 + + self.optimizer = nn.Adam( + trainable_params, + learning_rate=self.args.learning_rate, + beta1=0.9, + beta2=0.999, + eps=1e-8, + weight_decay=self.args.weight_decay + ) + + def _setup_datasets(self): + """Setup datasets for training""" + if self.train_dataset is None: + raise ValueError("Training dataset must be provided") + + # Process datasets if needed + self.train_dataset = self._prepare_dataset(self.train_dataset, is_train=True) + if self.eval_dataset is not None: + self.eval_dataset = self._prepare_dataset(self.eval_dataset, is_train=False) + + def _prepare_dataset(self, dataset, is_train=True): + """Prepare dataset for training/evaluation""" + # If it's already a GeneratorDataset, return as is + if isinstance(dataset, GeneratorDataset): + return dataset + + # Convert HuggingFace dataset to MindSpore dataset + def generator(): + for item in dataset: + yield self._preprocess_function(item) + + column_names = ["input_ids", "attention_mask", "labels"] + + ms_dataset = GeneratorDataset( + generator, + column_names=column_names, + shuffle=is_train + ) + + # Batch the dataset + ms_dataset = ms_dataset.batch( + batch_size=self.args.per_device_train_batch_size if is_train else self.args.per_device_eval_batch_size, + drop_remainder=is_train + ) + + return ms_dataset + + def _preprocess_function(self, examples): + """Preprocess a single example""" + # Get text from the configured field + text = examples.get(self.args.dataset_text_field, "") + + # Tokenize + tokenized = self.tokenizer( + text, + truncation=True, + padding="max_length", + max_length=self.args.max_seq_length, + return_tensors="ms" + ) + + # For causal LM, labels are the same as input_ids + labels = tokenized["input_ids"].copy() + + # Replace padding token id with -100 for loss computation + if self.tokenizer.pad_token_id is not None: + labels[labels == self.tokenizer.pad_token_id] = -100 + + return { + "input_ids": tokenized["input_ids"], + "attention_mask": tokenized["attention_mask"], + "labels": labels + } + + def compute_loss(self, model, inputs): + """Compute loss for a batch of inputs""" + # Forward pass + 
outputs = model(**inputs) + + # Get loss + if hasattr(outputs, "loss"): + return outputs.loss + else: + # Compute loss manually if model doesn't return it + logits = outputs.logits + labels = inputs.get("labels") + + if labels is not None: + # Shift for causal LM + shift_logits = logits[..., :-1, :].reshape(-1, logits.shape[-1]) + shift_labels = labels[..., 1:].reshape(-1) + + # Compute cross entropy loss + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(shift_logits, shift_labels) + return loss + + return None + + def training_step(self, batch): + """Perform a single training step""" + # Convert batch to model inputs + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "labels": batch[2] + } + + # Forward pass and compute loss + loss = self.compute_loss(self.model, inputs) + + # Backward pass + grads = ops.grad(self.compute_loss, self.model.trainable_params())(self.model, inputs) + + # Update parameters + self.optimizer(grads) + + return loss + + def train(self, resume_from_checkpoint=None): + """Main training loop""" + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(self.train_dataset)}") + logger.info(f" Num Epochs = {self.args.num_train_epochs}") + logger.info(f" Batch size = {self.args.per_device_train_batch_size}") + logger.info(f" Total optimization steps = {self.args.max_steps}") + + # Resume from checkpoint if provided + if resume_from_checkpoint: + self._load_checkpoint(resume_from_checkpoint) + + # Training loop + for epoch in range(int(self.args.num_train_epochs)): + self.epoch = epoch + epoch_loss = 0.0 + num_batches = 0 + + # Iterate through batches + for step, batch in enumerate(self.train_dataset.create_tuple_iterator()): + loss = self.training_step(batch) + + epoch_loss += loss.asnumpy() + num_batches += 1 + self.global_step += 1 + + # Logging + if self.global_step % self.args.logging_steps == 0: + avg_loss = epoch_loss / num_batches + logger.info(f"Step: {self.global_step}, Loss: {avg_loss:.4f}") + self.log_metrics("train", {"loss": avg_loss}) + + # Save checkpoint + if self.global_step % self.args.save_steps == 0: + self.save_checkpoint() + + # Evaluation + if self.args.eval_strategy != "no" and self.global_step % self.args.eval_steps == 0: + self.evaluate() + + # Check max steps + if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: + break + + # End of epoch + avg_epoch_loss = epoch_loss / num_batches + logger.info(f"Epoch {epoch} completed. 
Average Loss: {avg_epoch_loss:.4f}") + + # Run callbacks + for callback in self.callbacks: + callback(self, epoch=epoch, loss=avg_epoch_loss) + + if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: + break + + # Save final model + self.save_model() + + return {"global_step": self.global_step} + + def evaluate(self): + """Evaluation loop""" + if self.eval_dataset is None: + return {} + + logger.info("***** Running evaluation *****") + self.model.set_train(False) + + total_loss = 0.0 + num_batches = 0 + + for batch in self.eval_dataset.create_tuple_iterator(): + inputs = { + "input_ids": batch[0], + "attention_mask": batch[1], + "labels": batch[2] + } + + loss = self.compute_loss(self.model, inputs) + total_loss += loss.asnumpy() + num_batches += 1 + + avg_loss = total_loss / num_batches + logger.info(f"Evaluation Loss: {avg_loss:.4f}") + + self.model.set_train(True) + + metrics = {"eval_loss": avg_loss} + self.log_metrics("eval", metrics) + + return metrics + + def save_model(self, output_dir=None): + """Save model to disk""" + output_dir = output_dir or self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + + # Save model weights + save_checkpoint(self.model, os.path.join(output_dir, "model.ckpt")) + + # Save tokenizer + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + + # Save training arguments + with open(os.path.join(output_dir, "training_args.json"), "w") as f: + import json + json.dump(vars(self.args), f, indent=2) + + logger.info(f"Model saved to {output_dir}") + + def save_checkpoint(self): + """Save training checkpoint""" + checkpoint_dir = os.path.join(self.args.output_dir, f"checkpoint-{self.global_step}") + self.save_model(checkpoint_dir) + + # Save optimizer state + save_checkpoint(self.optimizer, os.path.join(checkpoint_dir, "optimizer.ckpt")) + + # Save training state + state = { + "global_step": self.global_step, + "epoch": self.epoch, + "best_metric": self.best_metric, + } + with open(os.path.join(checkpoint_dir, "trainer_state.json"), "w") as f: + import json + json.dump(state, f, indent=2) + + def _load_checkpoint(self, checkpoint_path): + """Load checkpoint from disk""" + # Load model weights + load_checkpoint(os.path.join(checkpoint_path, "model.ckpt"), self.model) + + # Load optimizer state + if os.path.exists(os.path.join(checkpoint_path, "optimizer.ckpt")): + load_checkpoint(os.path.join(checkpoint_path, "optimizer.ckpt"), self.optimizer) + + # Load training state + state_path = os.path.join(checkpoint_path, "trainer_state.json") + if os.path.exists(state_path): + with open(state_path, "r") as f: + import json + state = json.load(f) + self.global_step = state.get("global_step", 0) + self.epoch = state.get("epoch", 0) + self.best_metric = state.get("best_metric") + + logger.info(f"Resumed from checkpoint: {checkpoint_path}") + + def log_metrics(self, split, metrics): + """Log metrics""" + # Simple console logging + log_str = f"[{split}] Step {self.global_step}: " + log_str += ", ".join([f"{k}={v:.4f}" for k, v in metrics.items()]) + logger.info(log_str) + + def save_metrics(self, split, metrics): + """Save metrics to file""" + metrics_file = os.path.join(self.args.output_dir, f"{split}_metrics.json") + with open(metrics_file, "w") as f: + import json + json.dump(metrics, f, indent=2) + + def save_state(self): + """Save trainer state""" + state_file = os.path.join(self.args.output_dir, "trainer_state.json") + state = { + "global_step": self.global_step, + "epoch": self.epoch, + "best_metric": self.best_metric, 
+ } + with open(state_file, "w") as f: + import json + json.dump(state, f, indent=2) + + def create_model_card(self, **kwargs): + """Create model card for the trained model""" + # Simple model card creation + model_card = f""" +# Model Card + +## Model Details +- Model type: Causal Language Model +- Training framework: MindSpore/MindNLP +- Dataset: {kwargs.get('dataset_name', 'Unknown')} + +## Training Details +- Number of epochs: {self.args.num_train_epochs} +- Batch size: {self.args.per_device_train_batch_size} +- Learning rate: {self.args.learning_rate} +- Total steps: {self.global_step} + +## Tags +{kwargs.get('tags', [])} +""" + + with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: + f.write(model_card) + + def push_to_hub(self, **kwargs): + """Push model to model hub (placeholder)""" + logger.warning("push_to_hub is not implemented for MindSpore models yet") + + +def save_checkpoint(model, path): + """Save model checkpoint""" + mindspore.save_checkpoint(model, path) + + +def load_checkpoint(path, model): + """Load model checkpoint""" + mindspore.load_checkpoint(path, model) diff --git a/open_r1/src/mind_openr1/utils/__init__.py b/open_r1/src/mind_openr1/utils/__init__.py new file mode 100644 index 000000000..be8bdcef0 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/__init__.py @@ -0,0 +1,8 @@ +from .data import get_dataset +from .import_utils import is_e2b_available, is_morph_available +from .model_utils import get_model, get_tokenizer +from .callbacks import get_callbacks +from .wandb_logging import init_wandb_training + + +__all__ = ["get_tokenizer", "is_e2b_available", "is_morph_available", "get_model", "get_dataset", "get_callbacks", "init_wandb_training"] diff --git a/open_r1/src/mind_openr1/utils/callbacks.py b/open_r1/src/mind_openr1/utils/callbacks.py new file mode 100644 index 000000000..88e656243 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/callbacks.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import subprocess +from typing import List + +from transformers import TrainerCallback +from transformers.trainer_callback import TrainerControl, TrainerState +from transformers.training_args import TrainingArguments + +from .evaluation import run_benchmark_jobs +from .hub import push_to_hub_revision + + +def is_slurm_available() -> bool: + # returns true if a slurm queueing system is available + try: + subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return True + except FileNotFoundError: + return False + + +class DummyConfig: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class PushToHubRevisionCallback(TrainerCallback): + def __init__(self, model_config) -> None: + self.model_config = model_config + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if state.is_world_process_zero: + global_step = state.global_step + + # WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround + # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken + dummy_config = DummyConfig( + hub_model_id=args.hub_model_id, + hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}", + output_dir=f"{args.output_dir}/checkpoint-{global_step}", + system_prompt=args.system_prompt, + ) + + future = push_to_hub_revision( + dummy_config, extra_ignore_patterns=["*.pt"] + ) # don't push the optimizer states + + if is_slurm_available(): + dummy_config.benchmarks = args.benchmarks + + def run_benchmark_callback(_): + print(f"Checkpoint {global_step} pushed to hub.") + run_benchmark_jobs(dummy_config, self.model_config) + + future.add_done_callback(run_benchmark_callback) + + +CALLBACKS = { + "push_to_hub_revision": PushToHubRevisionCallback, +} + + +def get_callbacks(train_config, model_config) -> List[TrainerCallback]: + callbacks = [] + for callback_name in train_config.callbacks: + if callback_name not in CALLBACKS: + raise ValueError(f"Callback {callback_name} not found in CALLBACKS.") + callbacks.append(CALLBACKS[callback_name](model_config)) + + return callbacks diff --git a/open_r1/src/mind_openr1/utils/code_providers.py b/open_r1/src/mind_openr1/utils/code_providers.py new file mode 100644 index 000000000..71830e6ae --- /dev/null +++ b/open_r1/src/mind_openr1/utils/code_providers.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Code execution providers for executing and evaluating code snippets.""" + +import abc +import asyncio +from typing import List, Optional + +from ..utils import is_e2b_available, is_morph_available + + +if is_e2b_available(): + from e2b_code_interpreter import AsyncSandbox + from e2b_code_interpreter.models import Execution + + from .routed_sandbox import RoutedSandbox +else: + AsyncSandbox = None + Execution = None + RoutedSandbox = None + +if is_morph_available(): + from morphcloud.api import MorphCloudClient + from morphcloud.sandbox import Sandbox + + from .routed_morph import RoutedMorphSandbox +else: + MorphCloudClient = None + Sandbox = None + RoutedMorphSandbox = None + + +class CodeExecutionProvider(abc.ABC): + """Abstract base class for code execution providers.""" + + @abc.abstractmethod + def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]: + """Execute multiple scripts and return their reward values. + + Args: + scripts: List of code scripts to execute + language: The programming language of the scripts + + Returns: + List of float rewards (one per script) + """ + pass + + +class E2BProvider(CodeExecutionProvider): + """Provider that executes code using E2B sandboxes.""" + + def __init__(self, num_parallel: int = 2, e2b_router_url: Optional[str] = None): + """Initialize the E2B provider. + + Args: + num_parallel: Number of parallel sandboxes to use + e2b_router_url: URL for the E2B router (if using router mode) + """ + if not is_e2b_available(): + raise ImportError( + "E2B is not available and required for this provider. Please install E2B with " + "`pip install e2b-code-interpreter` and add an API key to a `.env` file." + ) + + self.num_parallel = num_parallel + self.e2b_router_url = e2b_router_url + + def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]: + """Execute scripts using E2B sandboxes. + + If e2b_router_url is provided, uses the RoutedSandbox for batch processing. + Otherwise, uses direct AsyncSandbox with parallelization. 
+ """ + if self.e2b_router_url is not None: + routed_sandbox = RoutedSandbox(router_url=self.e2b_router_url) + + executions = routed_sandbox.run_code( + scripts=scripts, + languages=languages, + timeout=30, + request_timeout=28, + ) + + rewards = [] + for execution in executions: + try: + reward = float(execution.text) + rewards.append(reward) + except Exception: + rewards.append(None) + return rewards + + try: + rewards = self._run_async_from_sync(scripts, languages, self.num_parallel) + except Exception as e: + print(f"Error from E2B executor: {e}") + rewards = [0.0] * len(scripts) + + return rewards + + def _run_async_from_sync(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]: + """Function wrapping the `_run_async` function.""" + try: + rewards = asyncio.run(self._run_async(scripts, languages, num_parallel)) + except Exception as e: + print(f"Error from E2B executor async: {e}") + raise e + + return rewards + + async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]: + semaphore = asyncio.Semaphore(num_parallel) + + tasks = [self._run_script(script, languages, semaphore) for script in scripts] + + results = await asyncio.gather(*tasks) + rewards = list(results) + + return rewards + + async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float: + # We set a timeout margin, as the AsyncSandbox timeout does not seem to work + # These values are based on running 256 examples with the gold solution + # from open-r1/verifiable-coding-problems-python_decontaminated + # see scripts/benchmark_e2b.py + + SANDBOX_TIMEOUT = 30 + MARGIN = 2 + REQUEST_TIMEOUT = SANDBOX_TIMEOUT - MARGIN + ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN + + async with semaphore: + try: + sandbox = await AsyncSandbox.create(timeout=SANDBOX_TIMEOUT, request_timeout=REQUEST_TIMEOUT) + execution = await asyncio.wait_for( + sandbox.run_code(script, languages=languages), + timeout=ASYNCIO_TIMEOUT, + ) + return float(execution.text) + except (TypeError, ValueError): + return 0.0 + except asyncio.TimeoutError: + print("Operation timed out") + return 0.0 + except Exception as e: + print(f"Error in `_run_script` from E2B sandbox ID {sandbox.sandbox_id} : {e}") + return 0.0 + finally: + try: + await sandbox.kill() + except Exception as e: + print(f"Error from E2B executor kill with sandbox ID {sandbox.sandbox_id} : {e}") + + +class MorphProvider(CodeExecutionProvider): + """Provider that executes code using MorphCloud's Sandbox API.""" + + def __init__(self, num_parallel: int = 2, morph_router_url: Optional[str] = None): + """Initialize the Morph provider. + + Args: + num_parallel: Number of parallel executions to use + morph_router_url: URL for the MorphCloud router (if using router mode) + """ + if not is_morph_available(): + raise ImportError( + "MorphCloud is not available and required for this provider. Please install MorphCloud with " + "`pip install morphcloud` and add an API key to a `.env` file." + ) + + try: + from dotenv import load_dotenv + + load_dotenv() + except ImportError: + print("Warning: python-dotenv not installed. 
Environment variables must be set directly.") + + self.num_parallel = num_parallel + self.morph_router_url = morph_router_url + + if self.morph_router_url is not None: + self.routed_sandbox = RoutedMorphSandbox(router_url=self.morph_router_url) + return + + import os + + self.api_key = os.getenv("MORPH_API_KEY") + if not self.api_key: + raise ValueError("MorphCloud API key not found. Please set the MORPH_API_KEY environment variable.") + + try: + self.client = MorphCloudClient(api_key=self.api_key) + self.Sandbox = Sandbox + except ImportError as e: + raise ImportError(f"Required MorphCloud dependencies not installed: {e}") + + def execute_scripts(self, scripts: List[str], languages: List[str]) -> List[float]: + """Execute scripts using MorphCloud Sandbox API. + + Args: + scripts: List of Python scripts to execute + language: Programming language + + Returns: + List of float rewards (one per script) + """ + + if hasattr(self, "routed_sandbox"): + try: + results = self.routed_sandbox.run_code( + scripts=scripts, + languages=languages, + timeout=90, + request_timeout=96, + ) + + rewards = [] + for result in results: + try: + reward = float(result.text) + rewards.append(reward) + except (ValueError, AttributeError): + rewards.append(0.0) + return rewards + except Exception as e: + print(f"Error from MorphCloud router: {e}") + return [0.0] * len(scripts) + + import asyncio + + try: + rewards = asyncio.run(self._run_async(scripts, languages, self.num_parallel)) + except Exception as e: + print(f"Error from MorphCloud executor: {e}") + rewards = [0.0] * len(scripts) + + return rewards + + async def _run_async(self, scripts: List[str], languages: List[str], num_parallel: int) -> List[float]: + """Run multiple scripts concurrently with limited parallelism. + + Args: + scripts: List of scripts to execute + language: Programming language + num_parallel: Maximum number of concurrent executions + + Returns: + List of rewards + """ + + semaphore = asyncio.Semaphore(num_parallel) + + tasks = [self._run_script(script, languages, semaphore) for script in scripts] + + results = await asyncio.gather(*tasks) + + return list(results) + + async def _run_script(self, script: str, languages: List[str], semaphore: asyncio.Semaphore) -> float: + """Execute a single script in a MorphCloud Sandbox. 
+ + Args: + script: The script to execute + language: Programming language + semaphore: Semaphore to limit concurrency + + Returns: + Float reward from script execution + """ + SANDBOX_TIMEOUT = 90 + MARGIN = 6 + ASYNCIO_TIMEOUT = SANDBOX_TIMEOUT + MARGIN + + sandbox = None + async with semaphore: + try: + sandbox = await asyncio.to_thread(self.Sandbox.new, client=self.client, ttl_seconds=SANDBOX_TIMEOUT) + result = await asyncio.wait_for( + asyncio.to_thread( + sandbox.run_code, + script, + languages=languages, + timeout=SANDBOX_TIMEOUT, + ), + timeout=ASYNCIO_TIMEOUT, + ) + + reward = 0.0 + try: + if hasattr(result, "text") and result.text: + lines = result.text.strip().split("\n") + if lines: + try: + reward = float(lines[-1]) + except ValueError: + try: + reward = float(result.text.strip()) + except ValueError: + pass + elif hasattr(result, "stdout") and result.stdout: + lines = result.stdout.strip().split("\n") + if lines: + try: + reward = float(lines[-1]) + except ValueError: + pass + except (ValueError, AttributeError): + pass + + return reward + + except asyncio.TimeoutError: + return 0.0 + except Exception: + return 0.0 + finally: + if sandbox: + try: + await asyncio.to_thread(sandbox.close) + await asyncio.to_thread(sandbox.shutdown) + except Exception: + pass + + +def get_provider(provider_type: str = "e2b", **kwargs) -> CodeExecutionProvider: + """Factory function to get the appropriate code execution provider. + + Args: + provider_type: Type of provider to use ("e2b", "morph") + **kwargs: Additional arguments to pass to the provider + + Returns: + An instance of CodeExecutionProvider + """ + num_parallel = kwargs.pop("num_parallel", 2) + + if provider_type == "e2b": + # Extract E2B-specific arguments + e2b_router_url = kwargs.pop("e2b_router_url", None) + return E2BProvider( + num_parallel=num_parallel, + e2b_router_url=e2b_router_url, + ) + elif provider_type == "morph": + # Extract Morph-specific arguments + morph_router_url = kwargs.pop("morph_router_url", None) + return MorphProvider( + num_parallel=num_parallel, + morph_router_url=morph_router_url, + ) + else: + raise ValueError(f"Unknown provider type: {provider_type}") diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/__init__.py b/open_r1/src/mind_openr1/utils/competitive_programming/__init__.py new file mode 100644 index 000000000..081e16fea --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/__init__.py @@ -0,0 +1,19 @@ +from .cf_scoring import score_submission +from .code_patcher import patch_code +from .ioi_scoring import SubtaskResult, score_subtask, score_subtasks +from .ioi_utils import add_includes +from .morph_client import get_morph_client_from_env +from .piston_client import get_piston_client_from_env, get_slurm_piston_endpoints + + +__all__ = [ + "get_piston_client_from_env", + "get_slurm_piston_endpoints", + "get_morph_client_from_env", + "patch_code", + "score_submission", + "score_subtask", + "score_subtasks", + "add_includes", + "SubtaskResult", +] diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/cf_scoring.py b/open_r1/src/mind_openr1/utils/competitive_programming/cf_scoring.py new file mode 100644 index 000000000..d3ede4f7e --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/cf_scoring.py @@ -0,0 +1,146 @@ +import asyncio +import os +from io import BytesIO +from typing import Literal + +from async_lru import alru_cache + +from .piston_client import PistonClient +from .utils import batched + + +async def 
score_single_test_case( + client: PistonClient, + problem_data: dict, + test_input: str, + test_output: str, + submission: str, + submission_language: str = "cpp", +) -> tuple[str, str]: + if submission_language not in ["python", "cpp"]: + raise ValueError(f"Invalid submission language: {submission_language}") + try: + result = await client.send_execute( + { + "files": [ + {"name": f"main.{submission_language}", "content": submission}, + *( + [{"name": "checker.py", "content": problem_data["generated_checker"]}] + if problem_data["generated_checker"] + else [] + ), + {"name": "input.txt", "content": test_input}, + {"name": "correct_output.txt", "content": test_output}, + { + "name": "grader_config", + "content": "\n".join( + f"{key}={value}" + for key, value in { + "TIME_LIMIT": problem_data["time_limit"], + "MEMORY_LIMIT": problem_data["memory_limit"], + "INPUT_MODE": problem_data["input_mode"], + }.items() + ), + }, + ], + "run_timeout": (problem_data["time_limit"] + 10) * 1000, + # +10 seconds hard limit. time limits are handled by the codeforces script + }, + language="cf_python3" if submission_language == "python" else "c++17", + ) + except Exception as e: + print(f"Error scoring submission: {e}") + return False + + return result + + +@alru_cache(maxsize=32) # TODO make this configurable +async def get_generated_contest_tests(contest_id: str) -> list[dict]: + import pandas as pd + + import aiofiles + import aiofiles.os + + tests_folder = os.environ.get("CF_TESTS_FOLDER", None) + if not tests_folder: + raise ValueError( + "CF_TESTS_FOLDER environment variable not set! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information." + ) + if not await aiofiles.os.path.exists(tests_folder): + raise ValueError( + f"CF_TESTS_FOLDER path '{tests_folder}' does not exist! Please download the codeforces generated tests and set CF_TESTS_FOLDER to the folder path. See https://huggingface.co/datasets/open-r1/codeforces for more information." 
+ ) + parquet_path = os.path.join(tests_folder, f"test_cases_{int(contest_id):04d}.parquet") + if not await aiofiles.os.path.exists(parquet_path): + return {} + + # Read parquet file asynchronously + async with aiofiles.open(parquet_path, "rb") as f: + content = await f.read() + df = pd.read_parquet(BytesIO(content)) + + # Group by problem_id and convert to dictionary of lists + grouped_tests = df.groupby("problem_id").apply(lambda x: x[["input", "output"]].to_dict("records")).to_dict() + + return grouped_tests + + +async def get_generated_tests(problem_id: str) -> list[dict]: + contest_id = problem_id.split("/")[0] + return (await get_generated_contest_tests(contest_id)).get(problem_id, []) + + +async def score_submission( + client: PistonClient, + problem_data: dict, + submission: str, + test_batch_size: int = 1, + scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum", + no_compile_reward: float = -0.1, + no_submission_reward: float = -1.0, + submission_language: str = "cpp", +) -> float: + if submission_language not in ["python", "cpp"]: + raise ValueError(f"Invalid submission language: {submission_language}") + test_cases = problem_data["official_tests"] + (await get_generated_tests(problem_data["id"])) + # invalid/not a coding problem + if test_cases is None or len(test_cases) == 0: + return None + # no code extracted + if not submission: + return no_submission_reward + + passed_test_cases = 0 + # run one batch, check if any of them failed (0 score): if so stop evaluating (assuming non partial score); otherwise continue with the next batch of test cases. + for test_batch_to_run in batched(test_cases, test_batch_size) if test_batch_size >= 1 else [test_cases]: + results = await asyncio.gather( + *[ + asyncio.create_task( + score_single_test_case( + client, problem_data, test_case["input"], test_case["output"], submission, submission_language + ) + ) + for test_case in test_batch_to_run + ] + ) + if any(result and result["compile"]["code"] != 0 for result in results): + return no_compile_reward + + tests_passed_results = [ + result and result["run"]["code"] == 0 and result["run"]["stdout"].strip() == "1" for result in results + ] + if scoring_mode == "pass_fail" and any(not test_passed for test_passed in tests_passed_results): + break + passed_test_cases += sum(1 for test_passed in tests_passed_results if test_passed) + + pass_fail_score = 1.0 if passed_test_cases == len(test_cases) else 0.0 + + if scoring_mode == "pass_fail": + return pass_fail_score + elif scoring_mode == "partial": + return passed_test_cases / len(test_cases) + elif scoring_mode == "weighted_sum": + return pass_fail_score + 0.1 * (passed_test_cases / len(test_cases)) + else: + raise ValueError(f"Invalid scoring mode: {scoring_mode}") diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/code_patcher.py b/open_r1/src/mind_openr1/utils/competitive_programming/code_patcher.py new file mode 100644 index 000000000..4d5536020 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/code_patcher.py @@ -0,0 +1,123 @@ +import re + + +def fix_python3_imports(source_code): + """ + Fix common import and function changes between Python 3 versions + + Args: + source_code (str): The Python source code to update + + Returns: + str: The updated source code + """ + # Dictionary of patterns to replacements + replacements = [ + # Fix collections.abc imports (changed in Python 3.3+) + ( + r"from collections import 
(Mapping|Sequence|Set|Container|MutableMapping|MutableSet|MutableSequence)", + r"from collections.abc import \1", + ), + # Fix imp module deprecation (deprecated in 3.4) + (r"import imp", r"import importlib"), + # Fix asyncio.async() to asyncio.ensure_future() (renamed in 3.4.4) + (r"asyncio\.async\(", r"asyncio.ensure_future("), + # Fix inspect.getargspec to inspect.getfullargspec (deprecated in 3.5) + (r"inspect\.getargspec", r"inspect.getfullargspec"), + # Fix array.array 'c' type code to 'b' (removed in 3.9) + (r"array\.array\('c'", r"array.array('b'"), + # Fix backslash line continuation with multiple newlines (Python-specific issue) + (r"\\(\r\n|\r|\n)+", "\\\n"), + # some solutions use getlogin() to check if they are debugging or on an actual submission + (r"(?:os\s*\.\s*)?getlogin\s*\(\s*\)", "False"), + # Fix usage of fractions.gcd (moved to math in 3.5) + # 1. Fix direct usage: fractions.gcd -> math.gcd + (r"\bfractions\.gcd\b", r"math.gcd"), + # 2. Fix 'from fractions import gcd, X' -> 'from fractions import X' (start/middle) + (r"(from\s+fractions\s+import\s+(?:\([^)]*)?)\bgcd\s*,\s*", r"\1"), + # 3. Fix 'from fractions import X, gcd' -> 'from fractions import X' (end) + (r"(from\s+fractions\s+import\s+.*?\S)\s*,\s*\bgcd(\s*\)?\s*(?:#.*)?)", r"\1\2"), + # 4. Fix standalone 'from fractions import gcd' -> 'from math import gcd' + (r"from\s+fractions\s+import\s+\(?\s*gcd\s*\)?", r""), + # --- End: Replacement for the faulty line --- + ] + + lines = source_code.splitlines() + last_import = max( + [ + i + for i, line in enumerate(lines) + if line.strip().startswith("import") or (line.strip().startswith("from") and "import" in line) + ], + default=0, + ) + import_section = "\n".join(lines[: last_import + 1]) + main_source = "\n".join(lines[last_import:]) + + if "fractions.gcd" in source_code and "import math" not in source_code: + import_section += "\nimport math" + elif "gcd" in source_code and "from math import gcd" not in source_code: + import_section += "\nfrom math import gcd" + + if "set_int_max_str_digits" not in source_code: + import_section += "\nimport sys\nsys.set_int_max_str_digits(0)" + + source_code = import_section + "\n" + main_source + + # Apply each replacement + for pattern, replacement in replacements: + source_code = re.sub(pattern, replacement, source_code) + + source_code = source_code.rstrip("\\") + + return source_code + + +def fix_cpp_includes(source_code): + # has most of the useful functions + code_header = "#include \n" + # use namespace std since models forget std:: often + if "using namespace std;" not in source_code and "std::" not in source_code: + code_header += "\nusing namespace std;\n\n" + return code_header + source_code + + +def is_patchable(lang): + return lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64", "cpp") or "C++" in lang + + +def patch_code(text, lang): + if not text: + return text + if lang in ("python", "python3", "Python 3", "PyPy 3", "PyPy 3-64"): + return fix_python3_imports(text) + elif "cpp" in lang or "C++" in lang: + return fix_cpp_includes(text) + return text + + +tests = [ + """read = lambda: map(int, input().split()) +n, m, z = read() +from fractions import gcd +ans = z // (n * m // gcd(n, m)) +print(ans)""", + """from fractions import Fraction,gcd + +a,b,c,d = [int(x) for x in input().split()] + +if a*d > b*c: + num = a*d-b*c + denom = a*d +else: + num = b*c-a*d + denom = b*c +div = gcd(num,denom) +print('%d/%d'%(num//div,denom//div))""", +] + +if __name__ == "__main__": + for test in tests: + 
print("ORIGINAL:", test, sep="\n\n") + print("PATCHED:", patch_code(test, "Python 3"), sep="\n\n") + print("=" * 50) diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/ioi_scoring.py b/open_r1/src/mind_openr1/utils/competitive_programming/ioi_scoring.py new file mode 100644 index 000000000..357156c89 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/ioi_scoring.py @@ -0,0 +1,335 @@ +import asyncio +from dataclasses import asdict, dataclass, field +from typing import Union + +from .ioi_utils import load_ioi_tests +from .piston_client import PistonClient, PistonError +from .utils import batched + + +@dataclass +class TestResult: + """ + Represents the result of a single test case execution. + + Attributes: + test_name: Name of the test case + score: Score achieved for this test (0.0 to 1.0) + status: Status code of the test result (e.g., 'AC', 'WA', 'TLE') + feedback: Detailed feedback message from the judge or an error message + """ + + test_name: str + score: float = 0.0 + status: str = "SKIPPED" + feedback: str = None + + +@dataclass +class SubtaskResult: + """ + Represents the result of a subtask containing multiple test cases. + + Attributes: + problem: Problem identifier + subtask: Subtask identifier + points: Maximum points available for this subtask + score_precision: Number of decimal places for score rounding + test_results: List of individual test case results + """ + + problem: str = None + subtask: str = None + + points: float = 0.0 + score_precision: int = 2 + + test_results: list[TestResult] = field(default_factory=list) + + @property + def status(self): + """ + Determines the overall status of the subtask based on the worst status among test results. + Status priorities are ordered from worst to best. + + Returns: + str: The status with the highest priority (lowest value) + """ + status_prios = {"CE": -1, "RE": 0, "WA": 1, "MLE": 2, "TLE": 3, "PA": 4, "AC": 5, "SKIPPED": 999} + return min([x.status for x in self.test_results], key=lambda x: status_prios[x]) + + @property + def score(self): + """ + Calculates the raw score for the subtask as the minimum score across all test results. + + Returns: + float: The rounded minimum score + """ + return ( + 0 + if not self.test_results + else round(min([test_result.score for test_result in self.test_results]), self.score_precision) + ) + + @property + def weighted_score(self): + """ + Calculates the weighted score by multiplying the raw score by the available points. + + Returns: + float: The rounded weighted score + """ + return ( + 0 + if not self.test_results + else round( + min([test_result.score for test_result in self.test_results]) * self.points, self.score_precision + ) + ) + + def to_dict(self): + """ + Converts the SubtaskResult to a dictionary representation. + + Returns: + dict: Dictionary containing all subtask result data + """ + return { + "problem": self.problem, + "subtask": self.subtask, + "score": self.score, + "weighted_score": self.weighted_score, + "points": self.points, + "score_precision": self.score_precision, + "status": self.status, + "test_results": [asdict(test_result) for test_result in self.test_results], + } + + +def _extract_single_status(score: float, feedback: str) -> str: + """ + Determines the status code based on the score and feedback message. 
+ + Args: + score: The numeric score (0.0 to 1.0) + feedback: The feedback message from the execution + + Returns: + str: Status code ('CE', 'MLE', 'TLE', 'WA', 'RE', 'AC', or 'PA') + """ + if score == 0.0: + if "Compilation error" in feedback: + return "CE" + elif "Memory limit exceeded" in feedback: + return "MLE" + elif "Time limit exceeded" in feedback: + return "TLE" + elif "Output isn't correct" in feedback: + return "WA" + else: + return "RE" + elif score == 1.0: + return "AC" + else: + return "PA" + + +async def score_single_test_case( + client: PistonClient, subtask: dict, test_name: str, test_input: str, test_output: str, submission: str +) -> TestResult: + """ + Scores a single test case by running the submission against the provided input and output. + + Args: + client: PistonClient instance for executing code + subtask: Dictionary containing subtask configuration + test_name: Name of the test case + test_input: Input data for the test case + test_output: Expected output for the test case + submission: Source code of the submission + + Returns: + TestResult: Result of the test case execution + """ + # Run submission for this test case + score, feedback = await run_submission(client, subtask, test_input, submission, test_output) + score = float(score) + + return TestResult( + test_name=test_name, score=score, status=_extract_single_status(score, feedback), feedback=feedback + ) + + +async def score_subtask( + client: PistonClient, + subtask: dict, + submission: str, + test_case_run_cache: Union[dict, None] = None, + test_batch_size: int = 1, +) -> SubtaskResult: + """ + Scores all test cases in a subtask. + + Args: + client: PistonClient instance for executing code + subtask: Dictionary containing subtask configuration + test_cases: Dictionary mapping test names to (input, output) tuples + submission: Source code of the submission + test_case_run_cache: Optional cache of previously run test cases + test_batch_size: evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. 
+ -1 to evaluate all test cases in parallel + Returns: + SubtaskResult: Result of the subtask evaluation + """ + subtask_result = SubtaskResult( + problem=subtask["id"], + subtask=subtask["subtask"], + points=subtask["score"], + score_precision=subtask["score_precision"], + test_results=[], + ) + + # tests that are not cached + tests_to_run = [ + (ti, test_name) + for ti, test_name in enumerate(subtask["test_names"]) + if test_case_run_cache is None or test_name not in test_case_run_cache + ] + + # initialize test results with cached results or empty (SKIPPED) TestResult objects + subtask_result.test_results = [ + test_case_run_cache[test_name] + if test_case_run_cache is not None and test_name in test_case_run_cache + else TestResult(test_name=test_name) + for test_name in subtask["test_names"] + ] + + # we skip submissions where no code was extracted + # no need to do anything, as we have a failed cached result + if not submission or any( + test_result.status != "SKIPPED" and test_result.score == 0.0 for test_result in subtask_result.test_results + ): + return subtask_result + + if "test_cases" in subtask: + test_cases = subtask["test_cases"] + if isinstance(subtask["test_cases"], list): + test_cases = {test_name: test for test_name, test in zip(subtask["test_names"], subtask["test_cases"])} + else: + test_cases = load_ioi_tests(subtask["year"], subtask["id"]) + + # run one batch, check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. + for test_batch_to_run in batched(tests_to_run, test_batch_size): + results = await asyncio.gather( + *[ + asyncio.create_task( + score_single_test_case( + client, subtask, test_name, test_cases[test_name][0], test_cases[test_name][1], submission + ) + ) + for _, test_name in test_batch_to_run + ] + ) + for (ti, test_name), test_result in zip(test_batch_to_run, results): + if test_case_run_cache is not None: + test_case_run_cache[test_name] = test_result + subtask_result.test_results[ti] = test_result + + # Stop early if it failed + if any(test_result.score == 0.0 for test_result in results): + break + + return subtask_result + + +async def score_subtasks( + client: PistonClient, subtasks: list[dict], submission: str, skip_mode: bool = True +) -> list[SubtaskResult]: + """ + Scores multiple subtasks for a submission. + + Args: + client: PistonClient instance for executing code + subtasks: List of dictionaries containing subtask configurations + submission: Source code of the submission + skip_mode: If True, evaluates test by test and stops after the first failure. Otherwise, runs all tests in parallel. Should be True when evaluating a large number of submissions. + + Returns: + list[SubtaskResult]: Results for all subtasks + """ + # avoid rerunning tests present in multiple subtasks + test_case_run_cache = {} + + return [await score_subtask(client, subtask, submission, test_case_run_cache, skip_mode) for subtask in subtasks] + + +async def run_submission( + client: PistonClient, problem: dict, test_input: str, submission: str, test_output: str | None = None +) -> tuple[str, str]: + """ + Executes a submission against a test case using the Piston execution environment. 
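A hedged end-to-end sketch of scoring one extracted submission with `score_subtasks`, assuming the package is installed, `PISTON_ENDPOINTS` is configured for the Piston client defined later in this patch, and `subtasks` is a list of IOI subtask dicts with the fields accessed above (`id`, `subtask`, `score`, `score_precision`, `test_names`, ...).

import asyncio

from mind_openr1.utils.competitive_programming.ioi_scoring import score_subtasks
from mind_openr1.utils.competitive_programming.piston_client import get_piston_client_from_env


async def grade(subtasks: list[dict], submission: str) -> float:
    client = get_piston_client_from_env()
    results = await score_subtasks(client, subtasks, submission, skip_mode=True)
    # IOI-style problem score: sum of the per-subtask weighted scores
    return sum(result.weighted_score for result in results)

# final_score = asyncio.run(grade(subtasks, extracted_code))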
+ + Args: + client: PistonClient instance for executing code + problem: Dictionary containing problem configuration + test_input: Input data for the test case + submission: Source code of the submission + test_output: Optional expected output for the test case + + Returns: + tuple[str, str]: A tuple containing (score, feedback) + """ + data = { + "files": [ + # the actual submission + {"name": f"graders/{problem['id'].lower()}.cpp", "content": submission}, + # pass the input + {"name": "input.txt", "content": test_input}, + # pass the expected output + *([{"name": "correct_output.txt", "content": test_output}] if test_output else []), + # grader files + *({"name": name, "content": content} for name, content in problem["grader_files"] if content), + ], + "run_timeout": round( + (problem["time_limit"] + 3) * 1000 + ), # +3 seconds hard limit. time limits are handled by the ioi script + "run_memory_limit": problem["memory_limit"], + } + return await execute_ioi(client, data) + + +async def execute_ioi(client, data) -> tuple[str, str]: + """ + Requests to the IOI package return the score as a float in the stdout, as well as optional feedback/errors in stderr. + Returns a tuple of (score, feedback). + """ + response = await client.send_execute(data) + + if "message" in response: + raise PistonError(response["message"]) + + if "compile" in response and response["compile"]["code"] != 0: + return "0", "Compilation error exit code " + str(response["compile"]["code"]) + "\n" + response["compile"][ + "stderr" + ] + + if "run" not in response: + raise PistonError(response) + + if response["run"]["code"] == 1 and "MemoryError" in response["run"]["stderr"]: + return "0", "Memory limit exceeded" + + # successful result + if response["run"]["stdout"]: + return response["run"]["stdout"], response["run"]["stderr"] + + if response["run"]["signal"] == "SIGKILL": + return "0", "Time limit exceeded" + + # other issues + if response["run"]["code"] != 0: + raise PistonError( + f"language={response['language']}, version={response['version']}, exit code={response['run']['code']}, stderr={response['run']['stderr']}, signal={response['run']['signal']}" + ) + return "0", "Unknown error" diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/ioi_utils.py b/open_r1/src/mind_openr1/utils/competitive_programming/ioi_utils.py new file mode 100644 index 000000000..02fe2b39b --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/ioi_utils.py @@ -0,0 +1,41 @@ +from collections import defaultdict +from functools import lru_cache + +from datasets import load_dataset + + +def add_includes(code: str, problem_id: str) -> str: + """ + Fix common compilation errors for IOI problems. + """ + if not code: + return code + # has most of the useful functions + code_header = "#include \n" + # include the problem header + problem_header_include = f'#include "{problem_id}.h"' + if problem_header_include not in code: + code_header += problem_header_include + "\n" + # use namespace std since models forget std:: often + if "using namespace std;" not in code and "std::" not in code: + code_header += "\nusing namespace std;\n\n" + return code_header + code + + +@lru_cache +def load_ioi_tests_for_year(year: int) -> dict[str, dict[str, tuple[str, str]]]: + """ + Load IOI tests for a given year. 
+ """ + tests_dataset = load_dataset("open-r1/ioi-test-cases", name=f"{year}", split="train") + test_cases = defaultdict(dict) + for test_case in tests_dataset: + test_cases[test_case["problem_id"]][test_case["test_name"]] = test_case["test_input"], test_case["test_output"] + return test_cases + + +def load_ioi_tests(year: int, problem_id: str) -> dict[str, tuple[str, str]]: + """ + Load IOI tests for a given year and problem id. + """ + return load_ioi_tests_for_year(year)[problem_id] diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/morph_client.py b/open_r1/src/mind_openr1/utils/competitive_programming/morph_client.py new file mode 100644 index 000000000..559b7f8a2 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/morph_client.py @@ -0,0 +1,742 @@ +import asyncio +import json +import logging +import os +import tempfile +from typing import Any, Dict, Optional, Tuple + +from dotenv import load_dotenv +from open_r1.utils.import_utils import is_morph_available + + +# Replace direct imports with conditional imports +if is_morph_available(): + from morphcloud.api import Instance, InstanceExecResponse, MorphCloudClient +else: + Instance = None + InstanceExecResponse = None + MorphCloudClient = None + + +# Silence verbose logs from dependencies +logging.getLogger("paramiko").setLevel(logging.ERROR) +logging.getLogger("httpx").setLevel(logging.ERROR) + + +class MorphCloudError(Exception): + pass + + +class MorphCloudExecutionClient: + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + spans_log_path: Optional[str] = None, + ): + """ + Initialize the MorphCloud execution client. + + Args: + api_key: Optional API key for MorphCloud. If not provided, will use MORPH_API_KEY env var. + base_url: Optional base URL for MorphCloud API. If not provided, will use default. + spans_log_path: Path to log API call spans to. Defaults to 'logs/morph_api_spans.jsonl'. + """ + + self.client = MorphCloudClient(api_key=api_key, base_url=base_url) + self._snapshot_lock = asyncio.Lock() + + async def _prepare_instance(self, snapshot_id=None) -> Instance: + """ + Prepare and start a MorphCloud instance. + + Args: + snapshot_id: Optional snapshot ID to use. If None, will get or create base snapshot. + + Returns: + Instance: The ready-to-use MorphCloud instance + + Raises: + TimeoutError: If instance fails to start or become ready + """ + + if not snapshot_id: + snapshot = await self._get_or_create_base_snapshot() + snapshot_id = snapshot.id + + try: + instance = await self.client.instances.astart( + snapshot_id, ttl_seconds=600 + ) # Auto-terminate after 10 minutes + await instance.await_until_ready(timeout=300) + return instance + except asyncio.TimeoutError as e: + print(f"Timeout while preparing instance: {str(e)}") + if instance: + try: + await instance.astop() + except Exception: + pass + raise + + async def _prepare_files(self, data: Dict[str, Any], temp_dir: str) -> Tuple[str, Dict[str, Any], Dict[str, str]]: + """ + Process files, determine problem ID, and prepare configuration. 
+ + Args: + data: Dictionary containing file information + temp_dir: Local temporary directory for file operations + + Returns: + tuple: (problem_id, grader_config, local_files) + + Raises: + ValueError: If problem ID cannot be determined + """ + # Extract problem ID + problem_id = None + graders_files = [] + for file in data["files"]: + if file["name"].startswith("graders/") and file["name"].endswith(".cpp"): + potential_id = os.path.basename(file["name"]).split(".")[0] + if potential_id not in ["grader", "manager", "stub"]: + problem_id = potential_id + + if file["name"].startswith("graders/"): + graders_files.append(file) + + if not problem_id: + raise ValueError("Could not determine problem ID from files") + + grader_config = { + "task_type": "Batch", + "code": problem_id, + "time_limit": data["run_timeout"] / 1000, + "memory_limit": data["run_memory_limit"] * 1024 * 1024, + } + + for file in graders_files: + if "manager.cpp" in file["name"]: + grader_config["task_type"] = "Communication" + grader_config["task_type_parameters_Communication_num_processes"] = 1 + grader_config["task_type_parameters_Communication_user_io"] = "std_io" + break + + config_path = os.path.join(temp_dir, "grader_config.json") + with open(config_path, "w") as f: + json.dump(grader_config, f) + + local_files = {"grader_config.json": config_path} + + for file in data["files"]: + local_path = os.path.join(temp_dir, os.path.basename(file["name"])) + with open(local_path, "w") as f: + f.write(file["content"]) + local_files[file["name"]] = local_path + + return problem_id, grader_config, local_files + + async def _upload_files(self, instance: Instance, local_files: Dict[str, str]) -> bool: + """ + Upload all necessary files to the instance. + + Args: + instance: The MorphCloud instance + local_files: Dictionary mapping remote paths to local file paths + + Returns: + bool: True if all uploads were successful + + Raises: + TimeoutError: If uploads time out + """ + for remote_name, local_path in local_files.items(): + target_path = f"/workspace/{remote_name}" + dir_path = os.path.dirname(target_path) + + if dir_path != "/workspace": + await instance.aexec(f"mkdir -p {dir_path}") + + await instance.aupload(local_path, target_path) + + await instance.aupload(local_files["grader_config.json"], "/workspace/graders/grader_config.json") + + return True + + async def _compile_code(self, instance: Instance) -> InstanceExecResponse: + """ + Compile the code on the instance. + + Args: + instance: The MorphCloud instance + + Returns: + InstanceExecResponse: Result of compilation + + Raises: + RuntimeError: If compilation fails + """ + compile_result = await instance.aexec("cd /workspace && ./compile") + + if compile_result.exit_code != 0: + raise RuntimeError(f"Compilation error exit code {compile_result.exit_code}\n{compile_result.stderr}") + + return compile_result + + async def _run_tests(self, instance: Instance, data: Dict[str, Any]) -> Tuple[str, str]: + """ + Run tests and evaluate results. 
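To make the unit conversions above concrete, this is the kind of `grader_config.json` that `_prepare_files` would write for a hypothetical Batch problem; all values below are made up.

import json

data = {"run_timeout": 2000, "run_memory_limit": 256}  # milliseconds, MB
grader_config = {
    "task_type": "Batch",   # becomes "Communication" when a manager.cpp grader is present
    "code": "example",      # hypothetical problem id parsed from graders/<id>.cpp
    "time_limit": data["run_timeout"] / 1000,                # 2.0 seconds
    "memory_limit": data["run_memory_limit"] * 1024 * 1024,  # 268435456 bytes
}
print(json.dumps(grader_config))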
+ + Args: + instance: The MorphCloud instance + data: Dictionary containing runtime parameters + + Returns: + tuple: (score, feedback) + + Raises: + TimeoutError: If test execution times out + """ + hard_timeout = data["run_timeout"] / 1000 + 3 + run_command = f"cd /workspace && timeout {hard_timeout}s ./run" + + run_result = await instance.aexec(run_command) + + if run_result.exit_code == 124 or run_result.exit_code == 137 or run_result.exit_code == 143: + return "0", "Time limit exceeded" + + if run_result.exit_code != 0 and "Memory limit exceeded" in run_result.stderr: + return "0", "Memory limit exceeded" + + if run_result.stdout: + return run_result.stdout.strip(), run_result.stderr.strip() + + if run_result.exit_code != 0: + return ( + "0", + f"Runtime error with exit code {run_result.exit_code}\n{run_result.stderr}", + ) + + return "0", "Unknown error" + + async def _execute_with_instance(self, instance: Instance, data: Dict[str, Any], temp_dir: str) -> Tuple[str, str]: + """Execute code using a prepared instance. + + Args: + instance: Ready MorphCloud instance + data: Execution data + temp_dir: Temporary directory for file operations + + Returns: + Tuple of (score, feedback) + + Raises: + Exception: Passes through exceptions for retry handling + """ + await instance.await_until_ready(timeout=300) + + problem_id, grader_config, local_files = await self._prepare_files(data, temp_dir) + + await self._upload_files(instance, local_files) + + try: + await self._compile_code(instance) + except RuntimeError as e: + return "0", str(e) + + score, feedback = await self._run_tests(instance, data) + return score, feedback + + async def _execute(self, data: Dict[str, Any]) -> Tuple[str, str]: + """ + Internal implementation of execute with no retry logic. + + Args: + data: Dictionary containing execution data + + Returns: + Tuple of (score, feedback) + + Raises: + Exception: If execution fails + """ + instance = None + + # Set timeouts to ensure we don't block indefinitely + # INSTANCE_TIMEOUT = 300 # 5 minutes for instance operations + TOTAL_EXECUTION_TIMEOUT = 600 # 10 minutes total execution time + + with tempfile.TemporaryDirectory(prefix="morph_exec_") as temp_dir: + snapshot = await self._get_or_create_base_snapshot() + instance = await self.client.instances.astart( + snapshot.id, ttl_seconds=600 + ) # Auto-terminate after 10 minutes + + async with instance: + # Use asyncio.wait_for to add overall timeout to the execution process + return await asyncio.wait_for( + self._execute_with_instance(instance, data, temp_dir), + timeout=TOTAL_EXECUTION_TIMEOUT, + ) + + async def execute(self, data: Dict[str, Any]) -> Tuple[str, str]: + """ + Execute code on MorphCloud based on the provided data with enhanced debugging and recovery. + + Orchestrates the following steps with proper error handling and retries: + 1. Prepare an instance (with retry) + 2. Set up workspace (with retry) + 3. Prepare and upload files (with retry) + 4. Compile code (with retry) + 5. 
Run tests (with retry) + + Args: + data: Dictionary containing: + - files: List of file objects with name and content fields + - run_timeout: Timeout in milliseconds + - run_memory_limit: Memory limit in MB + + Returns: + Tuple of (score, feedback) where: + - score is a string representation of a float between 0.0 and 1.0 + - feedback is a string with execution details + """ + # TODO: would be faster to pass info about the subtask as well to create a snapshot per subtask + # would cache the uploads of all files other than the submission: input.txt, correct_output.txt, grader files + # rather than reusing the snapshot that only has the compile/run scripts on it + # currently, run_submission -> client.execute(data) does not easily pass subtask info + + # Retry configuration + max_retries = 4 + base_delay = 1.0 + + # Try execution with retries and exponential backoff + for attempt in range(max_retries + 1): + try: + return await self._execute(data) + + except asyncio.TimeoutError: + if attempt < max_retries: + print(f"Execution timed out, retrying ({attempt + 1}/{max_retries})") + else: + return "0", "Execution timed out after multiple retries" + + except Exception as e: + # Calculate exponential backoff + if attempt < max_retries: + retry_delay = min(base_delay * (2**attempt), 30) # Exponential backoff, capped at 30 seconds + + print( + f"Execution failed with {type(e).__name__}: {str(e)}, retrying in {retry_delay:.2f}s ({attempt + 1}/{max_retries})" + ) + await asyncio.sleep(retry_delay) + else: + print(f"Execution failed after {max_retries} retries: {type(e).__name__}: {str(e)}") + return "0", f"Execution failed after multiple retries: {str(e)}" + + async def _get_or_create_base_snapshot(self): + """Get or create a snapshot with the necessary dependencies and scripts for evaluation.""" + + async with self._snapshot_lock: + base_snapshots = await self.client.snapshots.alist(digest="ioi-evaluation-morph") + + if not base_snapshots: + print("Creating base snapshot with build-essential cmake and g++") + + # Create base snapshot with minimal specs + base_snapshot = await self.client.snapshots.acreate( + vcpus=2, + memory=4096, + disk_size=10240, + metadata={"purpose": "ioi_evaluation"}, + ) + + # Start a temporary instance from the base snapshot + temp_instance = await self.client.instances.astart( + base_snapshot.id, ttl_seconds=900 + ) # Auto-terminate after 15 minutes + + try: + # Wait for the instance to be ready + await temp_instance.await_until_ready(timeout=300) + + # Get script contents + compile_script = await self._get_compile_script() + run_script = await self._get_run_script() + + # Use temporary directory to store scripts + with tempfile.TemporaryDirectory(prefix="morph_setup_") as temp_dir: + # Create paths for script files + compile_path = os.path.join(temp_dir, "compile.sh") + run_path = os.path.join(temp_dir, "run.sh") + + # Write scripts to temp files + with open(compile_path, "w") as f: + f.write(compile_script) + + with open(run_path, "w") as f: + f.write(run_script) + + async with temp_instance: + # Install dependencies + await temp_instance.aexec("apt-get update && apt-get install -y build-essential cmake g++") + + # Create workspace directory + await temp_instance.aexec( + "mkdir -p /workspace && mkdir -p /workspace/graders && chmod 777 /workspace" + ) + + # Upload scripts to instance + await temp_instance.aupload(compile_path, "/workspace/compile") + await temp_instance.aupload(run_path, "/workspace/run") + + # Make scripts executable + await temp_instance.aexec("chmod 
+x /workspace/compile /workspace/run") + + # Create snapshot from the prepared instance + final_snapshot = await temp_instance.asnapshot(digest="ioi-evaluation-morph") + + except Exception as e: + # Ensure instance is stopped if anything fails + await temp_instance.astop() + raise e + else: + final_snapshot = base_snapshots[0] + + return final_snapshot + + async def _get_compile_script(self): + """Get the compile script content.""" + return """#!/bin/bash + +manager_files=() # Array to store manager filenames +current_dir="$(pwd)" + +# Checker compilation path +checker_dir="$current_dir/checker" +checker_src="$checker_dir/checker.cpp" + +if [ -e "$checker_src" ]; then + echo "Compiling checker" + checker_exe="$checker_dir/checker" + g++ -x c++ -std=gnu++17 -O2 -o "$checker_exe" "$checker_src" + chmod +x "$checker_exe" + if [ $? -ne 0 ]; then + echo "Could not compile checker" >&2 + exit 1 + fi + echo "Compiled checker" +else + echo "No checker found at $checker_src" +fi + +# Graders path +graders_dir="$current_dir/graders" +if [ ! -e "$graders_dir" ]; then + echo "Grader folder was not found" >&2 + exit 1 +fi + +# Find and compile manager if it exists +manager_src="$graders_dir/manager.cpp" +if [ -e "$manager_src" ]; then + echo "Compiling manager" + manager_exe="$graders_dir/manager" + g++ -x c++ -std=gnu++17 -O2 -o "$manager_exe" "$manager_src" + chmod +x "$manager_exe" + if [ $? -ne 0 ]; then + echo "Could not compile manager" >&2 + exit 1 + fi + manager_files+=("manager") +fi + +# Process other graders +graders_list=($(ls "$graders_dir" | grep -v 'manager.cpp')) +for grader_name in "${graders_list[@]}"; do + manager_files+=("$grader_name") +done + +# Extract problem name and compile necessary files +problem_name='?' +for file in "${manager_files[@]}"; do + if [[ "$file" == *.h && "$file" != "testlib.h" ]]; then + problem_name="${file%.h}" + echo "Problem name: $problem_name" + break + fi +done + +files_to_compile=("graders/$problem_name.cpp") +[ -e graders/grader.cpp ] && files_to_compile+=("graders/grader.cpp") +[ -e graders/stub.cpp ] && files_to_compile+=("graders/stub.cpp") + +g++ -DEVAL -std=gnu++17 -O2 -pipe -s -o graders/"$problem_name" "${files_to_compile[@]}" +if [ $? -ne 0 ]; then + echo "Failed to compile $problem_name" >&2 + exit 1 +fi +chmod +x graders/"$problem_name" +echo "Compiled $problem_name from ${files_to_compile[@]} successfully" + +echo "Manager files: ${manager_files[@]}" +""" + + async def _get_run_script(self): + """Get the run script content.""" + return """#!/usr/bin/env bash +# disable stack limit so you don't get RE with recursion +ulimit -s unlimited +# some problems have 10MB+ input/output files in their test cases and you might get RE. uncomment if needed +# ulimit -f 2097152 + +# Check if grader_config.json exists +if [ ! -f "graders/grader_config.json" ]; then + echo "Error: graders/grader_config.json not found" >&2 + echo "Current directory contents:" >&2 + find . 
-type f -o -type d | sed -e 's/[^-][^\/]*\// |/g' -e 's/|\([^ ]\)/|-\1/' >&2 + exit 1 +fi + +# Read task type, code, and time limit from grader_config.json using grep and sed +TASK_TYPE=$(grep -o '"task_type":[^,}]*' graders/grader_config.json | sed 's/"task_type":\\s*"\\([^"]*\\)"/\\1/') +TASK_NAME=$(grep -o '"code":[^,}]*' graders/grader_config.json | sed 's/"code":\\s*"\\([^"]*\\)"/\\1/') +TIME_LIMIT=$(grep -o '"time_limit":[^,}]*' graders/grader_config.json | sed 's/"time_limit":\\s*\\([^,}]*\\)/\\1/') +MEMORY_LIMIT=$(grep -o '"memory_limit":[^,}]*' graders/grader_config.json | sed 's/"memory_limit":\\s*\\([^,}]*\\)/\\1/') +TASK_EXECUTABLE="graders/$TASK_NAME" + +# Set memory limit in KB (convert from bytes) +MEMORY_LIMIT_KB=0 +if [ -n "$MEMORY_LIMIT" ]; then + MEMORY_LIMIT_KB=$(($MEMORY_LIMIT / 1024)) + # Set the memory limit for the entire script and all child processes + ulimit -v $MEMORY_LIMIT_KB +fi + +# "Securely" handle the correct output file +CORRECT_OUTPUT="" +if [ -f "correct_output.txt" ]; then + # Read the content and immediately remove the file + CORRECT_OUTPUT=$(cat correct_output.txt) + rm -f correct_output.txt +fi + +# Create a temporary file for solution output +SOLUTION_OUTPUT=$(mktemp) + +# Global variables for process tracking +declare -a ALL_PIDS +declare -a FIFO_DIRS + +# Define cleanup function - simplified assuming timeout exists +function cleanup { + # Kill all tracked processes silently + exec 2>/dev/null + for pid in "${ALL_PIDS[@]:-}"; do + kill -9 "$pid" 2>/dev/null || true + done + + # Clean up FIFO directories + for dir in "${FIFO_DIRS[@]:-}"; do + [ -d "$dir" ] && rm -rf "$dir" + done + + # Clean up temporary files + rm -f "$SOLUTION_OUTPUT" || true + exec 2>&2 +} + +# Set up signal handling +trap cleanup EXIT INT TERM + +# Function to handle exit codes consistently across task types +function handle_exit_code { + local exit_code=$1 + + # Check for known timeout exit codes: + # - 124: standard timeout exit code + # - 137: SIGKILL (128+9), used for hard timeouts + # - 143: SIGTERM (128+15), can also be used for timeouts + if [ $exit_code -eq 124 ] || [ $exit_code -eq 137 ] || [ $exit_code -eq 143 ]; then + echo "0" + echo "Time limit exceeded (${TIME_LIMIT}s)" >&2 + return 124 + # All other non-zero exit codes should be treated as runtime errors + elif [ $exit_code -ne 0 ]; then + echo "0" + echo "Runtime error with exit code $exit_code" >&2 + return $exit_code + fi + + # Success case - return 0 + return 0 +} + +# Function to run a command with timeout (simplified assuming timeout exists) +function run_with_timeout { + local soft_limit=$1; shift + local command_to_run="$@" + + timeout --preserve-status "$soft_limit" "$@" + return $? +} + +case "$TASK_TYPE" in + "Batch") + # Simple batch execution with timeout + run_with_timeout "$TIME_LIMIT" ./$TASK_EXECUTABLE < input.txt > "$SOLUTION_OUTPUT" + exit_code=$? + + # Handle non-zero exit codes + handle_exit_code $exit_code + if [ $? -ne 0 ]; then + exit $? + fi + + # Check the output if we have a correct output + if [ -n "$CORRECT_OUTPUT" ]; then + # Restore the correct output file + echo "$CORRECT_OUTPUT" > correct_output.txt + + # Check if there's a custom checker + if [ -f "checker/checker" ]; then + # Let the checker handle everything + ./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT" + exit $? 
+ else + # Simple diff-based checking + if diff -bq <(echo "$CORRECT_OUTPUT") "$SOLUTION_OUTPUT" >/dev/null; then + echo "1" + echo "Output is correct (diff)" >&2 + else + echo "0" + echo "Output isn't correct (diff)" >&2 + exit 0 + fi + fi + else + # If no correct output was provided, just output the solution's output + cat "$SOLUTION_OUTPUT" + fi + ;; + + "Communication") + # Read Communication-specific parameters + NUM_PROCESSES=$(grep -o '"task_type_parameters_Communication_num_processes":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*\\([0-9]*\\)/\\1/' || true) + if [ -z "$NUM_PROCESSES" ]; then + NUM_PROCESSES=1 + fi + USER_IO=$(grep -o '"task_type_parameters_Communication_user_io":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/' || echo "std_io") + + # Read custom manager arguments if they exist + MANAGER_CUSTOM_ARGS="" + if grep -q '"task_type_parameters_Communication_manager_args"' graders/grader_config.json; then + MANAGER_CUSTOM_ARGS=$(grep -o '"task_type_parameters_Communication_manager_args":[^,}]*' graders/grader_config.json | sed 's/.*:\\s*"\\([^"]*\\)"/\\1/') + fi + + # Create temporary directories for FIFOs + for i in $(seq 0 $((NUM_PROCESSES-1))); do + FIFO_DIRS[$i]=$(mktemp -d) + + # Create FIFOs for this process + mkfifo "${FIFO_DIRS[$i]}/u${i}_to_m" + mkfifo "${FIFO_DIRS[$i]}/m_to_u${i}" + chmod 755 "${FIFO_DIRS[$i]}" + chmod 666 "${FIFO_DIRS[$i]}/u${i}_to_m" "${FIFO_DIRS[$i]}/m_to_u${i}" + done + + # Prepare manager arguments + MANAGER_ARGS="" + for i in $(seq 0 $((NUM_PROCESSES-1))); do + MANAGER_ARGS="$MANAGER_ARGS ${FIFO_DIRS[$i]}/u${i}_to_m ${FIFO_DIRS[$i]}/m_to_u${i}" + done + + # Add custom manager arguments if specified + if [ -n "$MANAGER_CUSTOM_ARGS" ]; then + MANAGER_ARGS="$MANAGER_ARGS $MANAGER_CUSTOM_ARGS" + fi + + # Start all user processes first + for i in $(seq 0 $((NUM_PROCESSES-1))); do + if [ "$USER_IO" = "fifo_io" ]; then + # Pass FIFOs as arguments + ARGS="${FIFO_DIRS[$i]}/m_to_u${i} ${FIFO_DIRS[$i]}/u${i}_to_m" + if [ "$NUM_PROCESSES" -ne 1 ]; then + ARGS="$ARGS $i" + fi + ./$TASK_EXECUTABLE $ARGS & + ALL_PIDS+=($!) + else + # Use stdin/stdout redirection + if [ "$NUM_PROCESSES" -ne 1 ]; then + ./$TASK_EXECUTABLE "$i" < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null & + ALL_PIDS+=($!) + else + ./$TASK_EXECUTABLE < "${FIFO_DIRS[$i]}/m_to_u${i}" > "${FIFO_DIRS[$i]}/u${i}_to_m" 2>/dev/null & + ALL_PIDS+=($!) + fi + fi + done + + # Run the manager with timeout using direct pipe from input.txt + run_with_timeout "$TIME_LIMIT" ./graders/manager $MANAGER_ARGS < input.txt > "$SOLUTION_OUTPUT" + + exit_code=$? + + # Handle non-zero exit codes + handle_exit_code $exit_code + if [ $? -ne 0 ]; then + exit $? + fi + + # Check the output if we have a correct output AND there's a checker (otherwise we assume the manager handles everything) + if [ -n "$CORRECT_OUTPUT" ] && [ -f "checker/checker" ]; then + # Restore the correct output file + echo "$CORRECT_OUTPUT" > correct_output.txt + + # Let the checker handle it + ./checker/checker input.txt correct_output.txt "$SOLUTION_OUTPUT" + exit $? + else + # we assume the manager handles it + cat "$SOLUTION_OUTPUT" + fi + ;; + + *) + echo "0" + echo "Unsupported task type \"$TASK_TYPE\"" >&2 + exit 1 + ;; +esac +""" + + +def get_morph_client_from_env(session=None) -> MorphCloudExecutionClient: + """ + Creates a MorphCloudExecutionClient instance using environment variables. 
+ + Environment variables: + MORPH_API_KEY: API key for MorphCloud + + Args: + session: Optional aiohttp.ClientSession to use for HTTP requests + + Returns: + MorphCloudExecutionClient: A configured MorphCloud execution client + """ + if not is_morph_available(): + raise ImportError( + "MorphCloud is not available and required for this function. Please install MorphCloud with " + "`pip install morphcloud` and add an API key to a `.env` file." + ) + + load_dotenv() + api_key = os.environ.get("MORPH_API_KEY") + if not api_key: + raise ValueError("MORPH_API_KEY environment variable is required") + + return MorphCloudExecutionClient(api_key=api_key) + + +# noqa: W293 diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/piston_client.py b/open_r1/src/mind_openr1/utils/competitive_programming/piston_client.py new file mode 100644 index 000000000..7dfc9a5ec --- /dev/null +++ b/open_r1/src/mind_openr1/utils/competitive_programming/piston_client.py @@ -0,0 +1,224 @@ +import asyncio +import os +import random +import re +import subprocess +from collections import Counter +from functools import lru_cache + +import aiohttp + + +class PistonError(Exception): + pass + + +@lru_cache(maxsize=1) +def get_piston_client_from_env(session=None): + piston_endpoints = os.getenv("PISTON_ENDPOINTS") + if piston_endpoints is None: + raise ValueError( + "For IOI/CF problems Piston endpoints running our IOI package are required. Please add a list of valid Piston endpoints to a PISTON_ENDPOINTS variable in a `.env` file." + ) + piston_endpoints = sorted( + piston_endpoints.split(",") if piston_endpoints != "slurm" else get_slurm_piston_endpoints() + ) + gpu_nb = int(os.getenv("LOCAL_RANK", 0)) # per‑GPU index + world = int(os.getenv("WORLD_SIZE", 1)) # total GPUs + if world > 1: + print(f"Using a subset of piston endpoints for GPU#{gpu_nb}") + piston_endpoints = piston_endpoints[gpu_nb::world] + random.shuffle(piston_endpoints) + max_requests_per_endpoint = os.getenv("PISTON_MAX_REQUESTS_PER_ENDPOINT", "1") + return PistonClient(piston_endpoints, session, max_requests_per_endpoint=int(max_requests_per_endpoint)) + + +class PistonClient: + """ + A client that will automatically load balance across multiple Piston (https://github.com/engineer-man/piston) workers. 
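The per-GPU endpoint sharding in `get_piston_client_from_env` above is plain list slicing; a toy example with placeholder worker URLs:

piston_endpoints = [f"http://worker-{i}:2000/api/v2" for i in range(8)]  # placeholder URLs
gpu_nb, world = 1, 4  # i.e. LOCAL_RANK=1, WORLD_SIZE=4
print(piston_endpoints[gpu_nb::world])
# -> ['http://worker-1:2000/api/v2', 'http://worker-5:2000/api/v2']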
+ This assumes piston is running our custom cms_ioi package: https://github.com/guipenedo/piston/releases/ + We recommend starting the instances with the following script as otherwise some IOI problems will hit default limits: + ``` + export PISTON_COMPILE_TIMEOUT=60000 + export PISTON_RUN_TIMEOUT=60000 + export PISTON_OUTPUT_MAX_SIZE=1000000000 + export PISTON_MAX_FILE_SIZE=1000000000 + export PISTON_DISABLE_NETWORKING=true + export PISTON_REPO_URL=https://github.com/guipenedo/piston/releases/download/pkgs/index + mkdir /piston + + sed -i '/app.use(body_parser.urlencoded/c\ app.use(body_parser.urlencoded({ extended: true, limit: \"512mb\" }));' src/index.js + sed -i '/app.use(body_parser.json/c\ app.use(body_parser.json({ limit: \"512mb\" }));' src/index.js + + # Start server in background + node src``` + + Piston docs for API usage: https://piston.readthedocs.io/en/latest/api-v2/ + """ + + def __init__( + self, + base_endpoint: str | list[str] = "http://ip-10-53-80-65:3223/api/v2", + session=None, + max_requests_per_endpoint=1, + ): + self.max_requests_per_endpoint = max_requests_per_endpoint + self.base_endpoints = [base_endpoint] if isinstance(base_endpoint, str) else base_endpoint + if len(self.base_endpoints) == 0: + raise ValueError("No Piston endpoints provided. Please check your PISTON_ENDPOINTS environment variable.") + self.endpoint_ids = {endpoint: i for i, endpoint in enumerate(self.base_endpoints)} + + self._session = session + self.endpoint_tokens = asyncio.Queue(maxsize=max_requests_per_endpoint * len(self.base_endpoints)) + + for _ in range(max_requests_per_endpoint): + for base_endpoint in self.base_endpoints: + self.endpoint_tokens.put_nowait(base_endpoint) + self._endpoint_failures = Counter() + self._unhealthy_endpoints = set() + self._endpoint_failures_lock = asyncio.Lock() + + @property + def session(self): + if self._session is None: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(sock_read=30), + connector=aiohttp.TCPConnector( + limit=self.max_requests_per_endpoint * len(self.base_endpoints), + ttl_dns_cache=300, + keepalive_timeout=5 * 60, + ), + ) + return self._session + + async def _wait_for_endpoint(self): + endpoint = await self.endpoint_tokens.get() + return endpoint + + async def _release_endpoint(self, endpoint): + await self.endpoint_tokens.put(endpoint) + + async def _send_request(self, endpoint, route, data=None, method="post"): + async with self.session.request( + method, f"{endpoint.rstrip('/')}/{route}", json=data, headers={"Content-Type": "application/json"} + ) as response: + return await response.json(content_type=None) + + async def _send_to_all(self, route, data=None, method="post"): + return await asyncio.gather( + *[self._send_request(endpoint, route, data, method) for endpoint in self.base_endpoints] + ) + + async def _send_to_one(self, endpoint, route, data=None, method="post"): + return await self._send_request(endpoint, route, data, method) + + async def install_package(self, language, version): + return await self._send_to_all("packages", {"language": language, "version": version}, method="post") + + async def uninstall_package(self, language, version): + return await self._send_to_all("packages", {"language": language, "version": version}, method="delete") + + async def get_supported_runtimes(self): + return await self._send_to_all("runtimes", method="get") + + async def _check_failed_endpoint(self, endpoint): + async with self._endpoint_failures_lock: + if endpoint in self._unhealthy_endpoints: + return + try: 
+ await asyncio.sleep(5) + await self.get_supported_runtimes() + except Exception as e: + print(f"Error checking endpoint {endpoint}, dropping it ({e})") + self._unhealthy_endpoints.add(endpoint) + if len(self._unhealthy_endpoints) >= len(self.base_endpoints): + raise PistonError("All endpoints are unhealthy. Please check your Piston workers.") + + async def send_execute(self, data, language="cms_ioi", max_retries=5): + data = data | { + "language": language, + "version": "*", + } + + base_delay = 1.0 + + status = None + endpoint = None + + for attempt in range(max_retries + 1): + try: + endpoint = await self._wait_for_endpoint() + if attempt > 0: + await asyncio.sleep(1) + async with self.session.post( + f"{endpoint.rstrip('/')}/execute", json=data, headers={"Content-Type": "application/json"} + ) as response: + status = response.status + res_json = await response.json(content_type=None) + + if status != 200: + raise PistonError(f"Server error. status={status}. {res_json}") + if res_json is None: + raise PistonError(f"Empty response. status={status}") + # piston overloaded + if "run" in res_json and "Resource temporarily unavailable" in res_json["run"].get("stderr", ""): + raise PistonError(f"Piston overloaded: {res_json['run']['stderr']}") + return res_json + + except (PistonError, asyncio.TimeoutError, aiohttp.ClientConnectionError, RuntimeError) as e: + # Only retry if we haven't reached max retries yet + if attempt < max_retries: + # Calculate backoff with jitter + delay = min(base_delay * (2**attempt), 10) # Exponential backoff, capped at 10 seconds + jitter = delay * 0.2 * (2 * asyncio.get_event_loop().time() % 1 - 0.5) # Add ±10% jitter + retry_delay = delay + jitter + print(f"Retrying in {retry_delay:.2f} seconds [{self.endpoint_ids[endpoint]}] {endpoint} - {e}") + + # special case: worker died + if isinstance(e, aiohttp.ClientConnectionError) and "Connect call failed" in str(e): + await self._check_failed_endpoint(endpoint) + else: + # hopefully we won't get this one again + await self._release_endpoint(endpoint) + endpoint = None + + await asyncio.sleep(retry_delay) + else: + await self._check_failed_endpoint(endpoint) + except Exception as e: + print(f"Propagating exception {type(e)}: {e}") + raise e + finally: + # Ensure endpoint is always released, even if an exception occurs + if endpoint is not None: + try: + await self._release_endpoint(endpoint) + except Exception as e: + print(f"Error releasing endpoint {endpoint}: {e}") + endpoint = None + + +def get_slurm_piston_endpoints(): + """Get list of active piston worker endpoints from squeue output""" + # Run squeue command to get job name, hostname and status, filtering for RUNNING state + result = subprocess.run( + ["squeue", '--format="%j %N %T"', "--noheader", "--states=RUNNING"], capture_output=True, text=True + ) + + # Split output into lines and skip header + lines = result.stdout.strip().split("\n") + + endpoints = [] + for line in lines: + # Parse job name from squeue output + fields = line.split() + job_name = fields[0].strip('"') # Remove quotes + hostname = fields[1] + + # Extract port if job name matches pattern + match = re.match(r"piston-worker-(\d+)", job_name) + if match: + port = match.group(1) + endpoints.append(f"http://{hostname}:{port}/api/v2") + + return endpoints diff --git a/open_r1/src/mind_openr1/utils/competitive_programming/utils.py b/open_r1/src/mind_openr1/utils/competitive_programming/utils.py new file mode 100644 index 000000000..7e1bf730f --- /dev/null +++ 
b/open_r1/src/mind_openr1/utils/competitive_programming/utils.py
@@ -0,0 +1,11 @@
+from itertools import islice
+
+
+def batched(iterable, n):
+    "Batch data into lists of length n. The last batch may be shorter."
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        n = None  # n < 1 means "one batch with all items"; a bare `return iterable` inside a generator would yield nothing
+    it = iter(iterable)
+    while batch := list(islice(it, n)):
+        yield batch
diff --git a/open_r1/src/mind_openr1/utils/data.py b/open_r1/src/mind_openr1/utils/data.py
new file mode 100644
index 000000000..b151a8a7f
--- /dev/null
+++ b/open_r1/src/mind_openr1/utils/data.py
@@ -0,0 +1,65 @@
+import logging
+
+import datasets
+from datasets import DatasetDict, concatenate_datasets
+
+from ..configs import ScriptArguments
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_dataset(args: ScriptArguments) -> DatasetDict:
+    """Load a dataset or a mixture of datasets based on the configuration.
+
+    Args:
+        args (ScriptArguments): Script arguments containing dataset configuration.
+
+    Returns:
+        DatasetDict: The loaded datasets.
+    """
+    if args.dataset_name and not args.dataset_mixture:
+        logger.info(f"Loading dataset: {args.dataset_name}")
+        return datasets.load_dataset(args.dataset_name, args.dataset_config)
+    elif args.dataset_mixture:
+        logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets")
+        seed = args.dataset_mixture.seed
+        datasets_list = []
+
+        for dataset_config in args.dataset_mixture.datasets:
+            logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})")
+            ds = datasets.load_dataset(
+                dataset_config.id,
+                dataset_config.config,
+                split=dataset_config.split,
+            )
+            if dataset_config.columns is not None:
+                ds = ds.select_columns(dataset_config.columns)
+            if dataset_config.weight is not None:
+                ds = ds.shuffle(seed=seed).select(range(int(len(ds) * dataset_config.weight)))
+                logger.info(
+                    f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples"
+                )
+
+            datasets_list.append(ds)
+
+        if datasets_list:
+            combined_dataset = concatenate_datasets(datasets_list)
+            combined_dataset = combined_dataset.shuffle(seed=seed)
+            logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")
+
+            if args.dataset_mixture.test_split_size is not None:
+                combined_dataset = combined_dataset.train_test_split(
+                    test_size=args.dataset_mixture.test_split_size, seed=seed
+                )
+                logger.info(
+                    f"Split dataset into train and test sets with test size: {args.dataset_mixture.test_split_size}"
+                )
+                return combined_dataset
+            else:
+                return DatasetDict({"train": combined_dataset})
+        else:
+            raise ValueError("No datasets were loaded from the mixture configuration")
+
+    else:
+        raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")
diff --git a/open_r1/src/mind_openr1/utils/evaluation.py b/open_r1/src/mind_openr1/utils/evaluation.py
new file mode 100644
index 000000000..e79cd2972
--- /dev/null
+++ b/open_r1/src/mind_openr1/utils/evaluation.py
@@ -0,0 +1,118 @@
+import subprocess
+from typing import TYPE_CHECKING, Dict, Union
+
+from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
+
+
+if TYPE_CHECKING:
+    from trl import GRPOConfig, SFTConfig, ModelConfig
+
+import base64
+import os
+
+
+# We need a special environment setup to launch vLLM from within Slurm training jobs.
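A minimal usage sketch for `get_dataset` above, covering the single-dataset path; the dataset id is only an example, and the mixture path additionally expects `dataset_mixture` to have been parsed into the `DatasetMixtureConfig`/`DatasetConfig` objects defined in `configs.py`.

from mind_openr1.configs import ScriptArguments
from mind_openr1.utils.data import get_dataset

args = ScriptArguments(dataset_name="open-r1/OpenR1-Math-220k")  # example dataset id
dataset = get_dataset(args)
print(dataset)  # a DatasetDict, typically with a "train" split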
+# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105 +# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269 +user_home_directory = os.path.expanduser("~") +VLLM_SLURM_PREFIX = [ + "env", + "-i", + "bash", + "-c", + f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ", +] + + +def register_lighteval_task( + configs: Dict[str, str], + eval_suite: str, + task_name: str, + task_list: str, + num_fewshot: int = 0, +): + """Registers a LightEval task configuration. + + - Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl + - Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks + + Args: + configs (Dict[str, str]): The dictionary to store the task configuration. + eval_suite (str, optional): The evaluation suite. + task_name (str): The name of the task. + task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0". + num_fewshot (int, optional): The number of few-shot examples. Defaults to 0. + is_custom_task (bool, optional): Whether the task is a custom task. Defaults to False. + """ + # Format task list in lighteval format + task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(",")) + configs[task_name] = task_list + + +LIGHTEVAL_TASKS = {} + +register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "math_500", "math_500", 0) +register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime24", "aime24", 0) +register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "aime25", "aime25", 0) +register_lighteval_task(LIGHTEVAL_TASKS, "lighteval", "gpqa", "gpqa:diamond", 0) +register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0) +register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb_v4", "lcb:codegeneration_v4", 0) + + +def get_lighteval_tasks(): + return list(LIGHTEVAL_TASKS.keys()) + + +SUPPORTED_BENCHMARKS = get_lighteval_tasks() + + +def run_lighteval_job( + benchmark: str, + training_args: Union["SFTConfig", "GRPOConfig"], + model_args: "ModelConfig", +) -> None: + task_list = LIGHTEVAL_TASKS[benchmark] + model_name = training_args.hub_model_id + model_revision = training_args.hub_model_revision + # For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM + num_gpus = get_gpu_count_for_vllm(model_name, model_revision) + if get_param_count_from_repo_id(model_name) >= 30_000_000_000: + tensor_parallel = True + else: + num_gpus = 2 # Hack while cluster is full + tensor_parallel = False + + cmd = VLLM_SLURM_PREFIX.copy() + cmd_args = [ + f"--gres=gpu:{num_gpus}", + f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}", + "slurm/evaluate.slurm", + benchmark, + f'"{task_list}"', + model_name, + model_revision, + f"{tensor_parallel}", + f"{model_args.trust_remote_code}", + ] + if training_args.system_prompt is not None: + # encode to base64 to avoid issues with special characters + # we decode in the sbatch script + prompt_encoded = base64.b64encode(training_args.system_prompt.encode()).decode() + cmd_args.append(prompt_encoded) + cmd[-1] += " " + " ".join(cmd_args) + subprocess.run(cmd, check=True) + + +def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], 
model_args: "ModelConfig") -> None: + benchmarks = training_args.benchmarks + if len(benchmarks) == 1 and benchmarks[0] == "all": + benchmarks = get_lighteval_tasks() + # Evaluate on all supported benchmarks. Later we may want to include a `chat` option + # that just evaluates on `ifeval` and `mt_bench` etc. + + for benchmark in benchmarks: + print(f"Launching benchmark `{benchmark}`") + if benchmark in get_lighteval_tasks(): + run_lighteval_job(benchmark, training_args, model_args) + else: + raise ValueError(f"Unknown benchmark {benchmark}") diff --git a/open_r1/src/mind_openr1/utils/hub.py b/open_r1/src/mind_openr1/utils/hub.py new file mode 100644 index 000000000..25c4311c7 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/hub.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from concurrent.futures import Future + +from transformers import AutoConfig + +from huggingface_hub import ( + create_branch, + create_repo, + get_safetensors_metadata, + list_repo_commits, + list_repo_files, + list_repo_refs, + repo_exists, + upload_folder, +) +from trl import GRPOConfig, SFTConfig + + +logger = logging.getLogger(__name__) + + +def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future: + """Pushes the model to branch on a Hub repo.""" + + # Create a repo if it doesn't exist yet + repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True) + # Get initial commit to branch from + initial_commit = list_repo_commits(training_args.hub_model_id)[-1] + # Now create the branch we'll be pushing to + create_branch( + repo_id=training_args.hub_model_id, + branch=training_args.hub_model_revision, + revision=initial_commit.commit_id, + exist_ok=True, + ) + logger.info(f"Created target repo at {repo_url}") + logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...") + ignore_patterns = ["checkpoint-*", "*.pth"] + ignore_patterns.extend(extra_ignore_patterns) + future = upload_folder( + repo_id=training_args.hub_model_id, + folder_path=training_args.output_dir, + revision=training_args.hub_model_revision, + commit_message=f"Add {training_args.hub_model_revision} checkpoint", + ignore_patterns=ignore_patterns, + run_as_future=True, + ) + logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!") + + return future + + +def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig): + """Checks if a given Hub revision exists.""" + if repo_exists(training_args.hub_model_id): + if training_args.push_to_hub_revision is True: + # First check if the revision exists + revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches] + # If the revision exists, we next check it has a README file + if training_args.hub_model_revision in revisions: + repo_files = list_repo_files( + repo_id=training_args.hub_model_id, + 
revision=training_args.hub_model_revision, + ) + if "README.md" in repo_files and training_args.overwrite_hub_revision is False: + raise ValueError( + f"Revision {training_args.hub_model_revision} already exists. " + "Use --overwrite_hub_revision to overwrite it." + ) + + +def get_param_count_from_repo_id(repo_id: str) -> int: + """Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID.""" + try: + metadata = get_safetensors_metadata(repo_id) + return list(metadata.parameter_count.values())[0] + except Exception: + # Pattern to match products (like 8x7b) and single values (like 42m) + pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])" + matches = re.findall(pattern, repo_id.lower()) + + param_counts = [] + for full_match, number1, _, _, number2, _, unit in matches: + if number2: # If there's a second number, it's a product + number = float(number1) * float(number2) + else: # Otherwise, it's a single value + number = float(number1) + + if unit == "b": + number *= 1_000_000_000 # Convert to billion + elif unit == "m": + number *= 1_000_000 # Convert to million + + param_counts.append(number) + + if len(param_counts) > 0: + # Return the largest number + return int(max(param_counts)) + else: + # Return -1 if no match found + return -1 + + +def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int: + """vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs. + This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model. + """ + config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True) + # Get number of attention heads + num_heads = config.num_attention_heads + # Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus + while num_heads % num_gpus != 0 or 64 % num_gpus != 0: + logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus") + num_gpus -= 1 + return num_gpus diff --git a/open_r1/src/mind_openr1/utils/import_utils.py b/open_r1/src/mind_openr1/utils/import_utils.py new file mode 100644 index 000000000..5d6624302 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/import_utils.py @@ -0,0 +1,30 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
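+
+# Editor's note (illustrative usage; not part of the upstream patch): the
+# availability flags defined below are meant to guard optional code-execution
+# backends before they are used, e.g.:
+#
+#     from mind_openr1.utils.import_utils import is_e2b_available, is_morph_available
+#
+#     if not (is_e2b_available() or is_morph_available()):
+#         raise ImportError("Install `e2b` or `morphcloud` to enable code execution rewards.")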
+ +from transformers.utils.import_utils import _is_package_available + + +# Use same as transformers.utils.import_utils +_e2b_available = _is_package_available("e2b") + + +def is_e2b_available() -> bool: + return _e2b_available + + +_morph_available = _is_package_available("morphcloud") + + +def is_morph_available() -> bool: + return _morph_available diff --git a/open_r1/src/mind_openr1/utils/model_utils.py b/open_r1/src/mind_openr1/utils/model_utils.py new file mode 100644 index 000000000..8191c17ea --- /dev/null +++ b/open_r1/src/mind_openr1/utils/model_utils.py @@ -0,0 +1,42 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer + +from trl import ModelConfig, get_kbit_device_map, get_quantization_config + +from ..configs import GRPOConfig, SFTConfig + + +def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer: + """Get the tokenizer for the model.""" + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + + if training_args.chat_template is not None: + tokenizer.chat_template = training_args.chat_template + + return tokenizer + + +def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM: + """Get the model""" + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + **model_kwargs, + ) + return model diff --git a/open_r1/src/mind_openr1/utils/routed_morph.py b/open_r1/src/mind_openr1/utils/routed_morph.py new file mode 100644 index 000000000..835c784af --- /dev/null +++ b/open_r1/src/mind_openr1/utils/routed_morph.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import requests + + +class RoutedMorphSandbox: + """ + Client for the MorphCloud router service that mimics the API of MorphCloud's Sandbox. + + This class provides a simple interface to execute code via a central MorphCloud router, + which manages sandbox creation and cleanup. It allows batch processing of multiple scripts + in a single request for improved efficiency. + + Attributes: + router_url (str): The URL of the MorphCloud router service. + timeout (int): Execution timeout in seconds. + request_timeout (int): HTTP request timeout in seconds. 
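+
+    Example (illustrative only; assumes a MorphCloud router is already serving at this address):
+        sandbox = RoutedMorphSandbox(router_url="0.0.0.0:8001")
+        results = sandbox.run_code(["print('hello world')"])
+        print(results[0].text, results[0].exception_str)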
+ """ + + def __init__(self, router_url: str, timeout: int = 300, request_timeout: int = 60): + """ + Initialize the routed MorphCloud sandbox client. + + Args: + router_url: The URL of the MorphCloud router, including host and port. + timeout: Default execution timeout in seconds. + request_timeout: Default HTTP request timeout in seconds. + """ + self.router_url = router_url + self.timeout = timeout + self.request_timeout = request_timeout + + def run_code( + self, + scripts: List[str], + languages: Optional[List[str]] = None, + timeout: Optional[int] = None, + request_timeout: Optional[int] = None, + ) -> List: + """ + Execute multiple scripts using MorphCloud via the router. + + Args: + scripts: List of code scripts to execute. + languages: List of programming languages for each script. If None, defaults to Python for all scripts. + timeout: Execution timeout in seconds. If None, uses the instance timeout. + request_timeout: HTTP request timeout in seconds. If None, uses the instance request_timeout. + + Returns: + List of execution results with text and exception_str properties. + """ + + actual_timeout = timeout if timeout is not None else self.timeout + actual_request_timeout = request_timeout if request_timeout is not None else self.request_timeout + + # Default to Python for all scripts if languages is not provided + if languages is None: + languages = ["python"] * len(scripts) + + payload = { + "scripts": scripts, + "languages": languages, + "timeout": actual_timeout, + "request_timeout": actual_request_timeout, + } + + try: + endpoint = f"http://{self.router_url}/execute_batch" + response = requests.post(endpoint, json=payload, timeout=actual_request_timeout) + + if response.status_code != 200: + error = f"Request to MorphCloud router failed with status code: {response.status_code}" + print(error) + + results = [] + for _ in scripts: + results.append(type("obj", (object,), {"text": None, "exception_str": error})) + return results + + response_data = response.json() + results = [] + + for item in response_data: + # Log the response data to see what we're getting + # print(f"RoutedMorphSandbox: Got response item: {item}") + result = type( + "obj", + (object,), + { + "text": item.get("text"), + "exception_str": item.get("exception_str"), + }, + ) + results.append(result) + + return results + + except Exception as e: + error = f"Error communicating with MorphCloud router: {str(e)}" + print(error) + + results = [] + for _ in scripts: + results.append(type("obj", (object,), {"text": None, "exception_str": error})) + return results diff --git a/open_r1/src/mind_openr1/utils/routed_sandbox.py b/open_r1/src/mind_openr1/utils/routed_sandbox.py new file mode 100644 index 000000000..97bb65cf4 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/routed_sandbox.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
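+
+# Editor's note (illustrative): this client mirrors the single-script
+# `Sandbox.run_code` interface from `e2b_code_interpreter`, but sends a whole
+# batch of scripts in one HTTP call and returns standard `Execution` objects
+# (results, logs, error). See the `__main__` block at the bottom of this file
+# for a runnable example against a locally launched router (scripts/e2b_router.py).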
+ +from typing import List, Optional + +import requests +from e2b_code_interpreter.models import Execution, ExecutionError, Result + + +class RoutedSandbox: + """ + A sandbox environment that routes code execution requests to the E2B Router. + This class is designed for batched execution of scripts, primarily for Python code. + It mimics the usage of 'Sandbox' from 'e2b_code_interpreter', but adds support for batch processing. + + Attributes: + router_url (str): The URL of the E2B Router to which code execution requests are sent. + """ + + def __init__(self, router_url: str): + """ + Initializes the RoutedSandbox with the specified router URL. + + Args: + router_url (str): The URL of the E2B Router. + """ + self.router_url = router_url + + def run_code( + self, + scripts: list[str], + languages: Optional[List[str]] = None, + timeout: Optional[int] = None, + request_timeout: Optional[int] = None, + ) -> list[Execution]: + """ + Executes a batch of scripts in the sandbox environment. + + Args: + scripts (list[str]): A list of code scripts to execute. + languages (list[str], optional): List of programming languages for each script. If None, defaults to Python for all scripts. + timeout (Optional[int], optional): The maximum execution time for each script in seconds. Defaults to 300 seconds. + request_timeout (Optional[int], optional): The timeout for the HTTP request in seconds. Defaults to 30 seconds. + + Returns: + list[Execution]: A list of Execution objects containing the results, logs, and errors (if any) for each script. + """ + # Set default values for timeouts if not provided + if timeout is None: + timeout = 300 # Default to 5 minutes + if request_timeout is None: + request_timeout = 30 # Default to 30 seconds + + # Default to Python for all scripts if languages is not provided + if languages is None: + languages = ["python"] * len(scripts) + + # Prepare the payload for the HTTP POST request + payload = { + "scripts": scripts, + "languages": languages, + "timeout": timeout, + "request_timeout": request_timeout, + } + + # Send the request to the E2B Router + response = requests.post(f"http://{self.router_url}/execute_batch", json=payload) + if not response.ok: + print(f"Request failed with status code: {response.status_code}") + + # Parse the response and construct Execution objects + results = response.json() + output = [] + for result in results: + if result["execution"] is None: + # If execution is None, create an empty Execution object + # This can happen when a script times out or fails to execute + execution = Execution() + else: + execution = Execution( + results=[Result(**r) for r in result["execution"]["results"]], + logs=result["execution"]["logs"], + error=(ExecutionError(**result["execution"]["error"]) if result["execution"]["error"] else None), + execution_count=result["execution"]["execution_count"], + ) + output.append(execution) + + return output + + +if __name__ == "__main__": + # for local testing launch an E2B router with: python scripts/e2b_router.py + sbx = RoutedSandbox(router_url="0.0.0.0:8000") + codes = ["print('hello world')", "print('hello world)"] + executions = sbx.run_code(codes) # Execute Python inside the sandbox + + print(executions) diff --git a/open_r1/src/mind_openr1/utils/wandb_logging.py b/open_r1/src/mind_openr1/utils/wandb_logging.py new file mode 100644 index 000000000..e52f911c8 --- /dev/null +++ b/open_r1/src/mind_openr1/utils/wandb_logging.py @@ -0,0 +1,13 @@ +import os + + +def init_wandb_training(training_args): + """ + Helper function for 
setting up Weights & Biases logging tools. + """ + if training_args.wandb_entity is not None: + os.environ["WANDB_ENTITY"] = training_args.wandb_entity + if training_args.wandb_project is not None: + os.environ["WANDB_PROJECT"] = training_args.wandb_project + if training_args.wandb_run_group is not None: + os.environ["WANDB_RUN_GROUP"] = training_args.wandb_run_group From 9b389bbe98a7ec1708249d6265b20b662f72c57a Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Wed, 8 Oct 2025 13:26:26 +0800 Subject: [PATCH 3/7] add: sft --- open_r1/sh/sft.sh | 18 + open_r1/src/mind_openr1/__init__.py | 8 - open_r1/src/mind_openr1/sft.py | 227 ++----------- open_r1/src/mind_openr1/sft_trainer.py | 449 ------------------------- 4 files changed, 55 insertions(+), 647 deletions(-) create mode 100644 open_r1/sh/sft.sh delete mode 100644 open_r1/src/mind_openr1/sft_trainer.py diff --git a/open_r1/sh/sft.sh b/open_r1/sh/sft.sh new file mode 100644 index 000000000..aad426bd2 --- /dev/null +++ b/open_r1/sh/sft.sh @@ -0,0 +1,18 @@ +PYTHONPATH=/home/ma-user/work/mind-openr1/src python src/mind_openr1/sft.py \ + --model_name_or_path /home/ma-user/work/Qwen2.5-1.5B \ + --dataset_name open-r1/Mixture-of-Thoughts \ + --dataset_config math \ + --eos_token '<|im_end|>' \ + --learning_rate 4.0e-5 \ + --num_train_epochs 5 \ + --max_length 13312 \ + --per_device_train_batch_size 1 \ + --gradient_checkpointing \ + --bf16 True \ + --torch_dtype bfloat16 \ + --output_dir checkpoints/Qwen2.5-1.5B-SFT \ + --save_steps 100000 \ + + # --dataset_name /home/ma-user/work/mind-openr1/data/open-r1___mixture-of-thoughts \ + +# nohup bash sh/sft.sh > /home/ma-user/work/mind-openr1/logs/sft.log 2>&1 & \ No newline at end of file diff --git a/open_r1/src/mind_openr1/__init__.py b/open_r1/src/mind_openr1/__init__.py index 31cb1f347..e69de29bb 100644 --- a/open_r1/src/mind_openr1/__init__.py +++ b/open_r1/src/mind_openr1/__init__.py @@ -1,8 +0,0 @@ -""" -Mind-OpenR1: MindSpore implementation of OpenR1 -""" - -from .sft_trainer import SFTTrainer, SFTConfig -from .configs import ScriptArguments - -__all__ = ["SFTTrainer", "SFTConfig", "ScriptArguments"] diff --git a/open_r1/src/mind_openr1/sft.py b/open_r1/src/mind_openr1/sft.py index c51f6f82f..f6e3a4950 100644 --- a/open_r1/src/mind_openr1/sft.py +++ b/open_r1/src/mind_openr1/sft.py @@ -1,106 +1,29 @@ import logging import os import sys -from dataclasses import dataclass -from typing import Optional import mindspore from mindspore import context as ms_context import mindnlp import datasets -from mindnlp.transformers import ( - set_seed, - AutoTokenizer, - AutoModelForCausalLM, - get_last_checkpoint -) +import transformers +from transformers import set_seed +from transformers.trainer_utils import get_last_checkpoint -from mind_openr1.configs import ScriptArguments -from mind_openr1.sft_trainer import SFTTrainer, SFTConfig -from mind_openr1.utils import get_dataset +from mind_openr1.configs import ScriptArguments, SFTConfig +from mind_openr1.utils import get_dataset, get_model, get_tokenizer from mind_openr1.utils.callbacks import get_callbacks +from mind_openr1.utils.wandb_logging import init_wandb_training +from trl import ModelConfig, SFTTrainer, TrlParser, get_peft_config, setup_chat_format -ms_context.set_context(mode=ms_context.PYNATIVE_MODE) +# 支持 Ascend 设备运行,如果未配置则默认 CPU +device_target = os.environ.get("MS_DEVICE_TARGET", "Ascend") +device_id = int(os.environ.get("DEVICE_ID", "0")) +ms_context.set_context(mode=ms_context.PYNATIVE_MODE, device_target=device_target, 
device_id=device_id) logger = logging.getLogger(__name__) -@dataclass -class ModelConfig: - """Model configuration compatible with mindnlp""" - model_name_or_path: str - model_revision: str = "main" - trust_remote_code: bool = False - use_flash_attention_2: bool = False - lora_r: Optional[int] = None - lora_alpha: Optional[int] = None - lora_dropout: Optional[float] = None - lora_target_modules: Optional[list] = None - use_peft: bool = False - - -def get_tokenizer_mindnlp(model_args: ModelConfig): - """Get tokenizer using mindnlp""" - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - ) - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - return tokenizer - - -def get_model_mindnlp(model_args: ModelConfig): - """Get model using mindnlp""" - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - revision=model_args.model_revision, - trust_remote_code=model_args.trust_remote_code, - ms_dtype=mindspore.float16 if model_args.use_flash_attention_2 else mindspore.float32, - ) - - return model - - -def get_peft_config_dict(model_args: ModelConfig): - """Get PEFT configuration if enabled""" - if not model_args.use_peft: - return None - - peft_config = { - "r": model_args.lora_r or 16, - "lora_alpha": model_args.lora_alpha or 32, - "lora_dropout": model_args.lora_dropout or 0.1, - "target_modules": model_args.lora_target_modules or ["q_proj", "v_proj"], - "bias": "none", - "task_type": "CAUSAL_LM", - } - - return peft_config - - -def setup_chat_format(model, tokenizer): - """Setup chat format for model and tokenizer""" - if tokenizer.chat_template is None: - logger.info("No chat template provided, setting up ChatML format") - # Simple ChatML template - tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" - - # Add special tokens if needed - special_tokens = { - "additional_special_tokens": ["<|im_start|>", "<|im_end|>"] - } - tokenizer.add_special_tokens(special_tokens) - - # Resize model embeddings - model.resize_token_embeddings(len(tokenizer)) - - return model, tokenizer - - def main(script_args, training_args, model_args): set_seed(training_args.seed) @@ -112,9 +35,12 @@ def main(script_args, training_args, model_args): datefmt="%Y-%m-%d %H:%M:%S", handlers=[logging.StreamHandler(sys.stdout)], ) - log_level = logging.INFO if training_args.logging_steps > 0 else logging.WARNING + log_level = training_args.get_process_log_level() logger.setLevel(log_level) datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() logger.info(f"Model parameters {model_args}") logger.info(f"Script parameters {script_args}") @@ -127,23 +53,25 @@ def main(script_args, training_args, model_args): if last_checkpoint is not None and training_args.resume_from_checkpoint is None: logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") + if "wandb" in training_args.report_to: + init_wandb_training(training_args) + ###################################### # Load dataset, tokenizer, and model # ###################################### dataset = get_dataset(script_args) - # Optionally truncate training split if 
max_train_samples is provided if getattr(script_args, "max_train_samples", None): train_split = script_args.dataset_train_split max_n = int(script_args.max_train_samples) if max_n > 0: dataset[train_split] = dataset[train_split].select(range(min(max_n, len(dataset[train_split])))) - - tokenizer = get_tokenizer_mindnlp(model_args) - model = get_model_mindnlp(model_args) + tokenizer = get_tokenizer(model_args, training_args) + model = get_model(model_args, training_args) - # Setup chat format if needed - model, tokenizer = setup_chat_format(model, tokenizer) + if tokenizer.chat_template is None: + logger.info("No chat template provided, defaulting to ChatML.") + model, tokenizer = setup_chat_format(model, tokenizer, format="chatml") ############################ # Initialize the SFT Trainer @@ -154,7 +82,7 @@ def main(script_args, training_args, model_args): train_dataset=dataset[script_args.dataset_train_split], eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None), processing_class=tokenizer, - peft_config=get_peft_config_dict(model_args), + peft_config=get_peft_config(model_args), callbacks=get_callbacks(training_args, model_args), ) @@ -167,10 +95,8 @@ def main(script_args, training_args, model_args): checkpoint = training_args.resume_from_checkpoint elif last_checkpoint is not None: checkpoint = last_checkpoint - train_result = trainer.train(resume_from_checkpoint=checkpoint) - - metrics = train_result + metrics = train_result.metrics metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) @@ -180,19 +106,22 @@ def main(script_args, training_args, model_args): # Save model and create model card ################################## logger.info("*** Save model ***") - - # Save model + # Align the model's generation config with the tokenizer's eos token + # to avoid unbounded generation in the transformers `pipeline()` function + trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id trainer.save_model(training_args.output_dir) logger.info(f"Model saved to {training_args.output_dir}") # Save everything else on main process kwargs = { "dataset_name": script_args.dataset_name, - "tags": ["open-r1", "mindspore", "mindnlp"], + "tags": ["open-r1"], } - - # Create model card - trainer.create_model_card(**kwargs) + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) ########## # Evaluate @@ -213,88 +142,6 @@ def main(script_args, training_args, model_args): if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - - # Script arguments - parser.add_argument("--dataset_name", type=str, required=True) - parser.add_argument("--dataset_config", type=str, default=None) - parser.add_argument("--dataset_train_split", type=str, default="train") - parser.add_argument("--dataset_test_split", type=str, default="test") - parser.add_argument("--max_train_samples", type=int, default=None) - - # Model arguments - parser.add_argument("--model_name_or_path", type=str, required=True) - parser.add_argument("--model_revision", type=str, default="main") - parser.add_argument("--trust_remote_code", action="store_true") - parser.add_argument("--use_flash_attention_2", action="store_true") - parser.add_argument("--use_peft", action="store_true") - 
parser.add_argument("--lora_r", type=int, default=16) - parser.add_argument("--lora_alpha", type=int, default=32) - parser.add_argument("--lora_dropout", type=float, default=0.1) - parser.add_argument("--lora_target_modules", type=str, nargs="+", default=None) - - # Training arguments - parser.add_argument("--output_dir", type=str, required=True) - parser.add_argument("--num_train_epochs", type=int, default=3) - parser.add_argument("--per_device_train_batch_size", type=int, default=8) - parser.add_argument("--per_device_eval_batch_size", type=int, default=8) - parser.add_argument("--learning_rate", type=float, default=5e-5) - parser.add_argument("--weight_decay", type=float, default=0.0) - parser.add_argument("--max_seq_length", type=int, default=512) - parser.add_argument("--logging_steps", type=int, default=10) - parser.add_argument("--save_steps", type=int, default=500) - parser.add_argument("--eval_steps", type=int, default=500) - parser.add_argument("--eval_strategy", type=str, default="steps") - parser.add_argument("--max_steps", type=int, default=-1) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--do_eval", action="store_true") - parser.add_argument("--push_to_hub", action="store_true") - parser.add_argument("--resume_from_checkpoint", type=str, default=None) - parser.add_argument("--dataset_text_field", type=str, default="text") - - args = parser.parse_args() - - # Create config objects - script_args = ScriptArguments( - dataset_name=args.dataset_name, - dataset_config=args.dataset_config, - dataset_train_split=args.dataset_train_split, - dataset_test_split=args.dataset_test_split, - max_train_samples=args.max_train_samples, - ) - - model_args = ModelConfig( - model_name_or_path=args.model_name_or_path, - model_revision=args.model_revision, - trust_remote_code=args.trust_remote_code, - use_flash_attention_2=args.use_flash_attention_2, - use_peft=args.use_peft, - lora_r=args.lora_r, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - lora_target_modules=args.lora_target_modules, - ) - - training_args = SFTConfig( - output_dir=args.output_dir, - num_train_epochs=args.num_train_epochs, - per_device_train_batch_size=args.per_device_train_batch_size, - per_device_eval_batch_size=args.per_device_eval_batch_size, - learning_rate=args.learning_rate, - weight_decay=args.weight_decay, - max_seq_length=args.max_seq_length, - logging_steps=args.logging_steps, - save_steps=args.save_steps, - eval_steps=args.eval_steps, - eval_strategy=args.eval_strategy, - max_steps=args.max_steps, - seed=args.seed, - do_eval=args.do_eval, - push_to_hub=args.push_to_hub, - resume_from_checkpoint=args.resume_from_checkpoint, - dataset_text_field=args.dataset_text_field, - ) - - main(script_args, training_args, model_args) \ No newline at end of file + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/open_r1/src/mind_openr1/sft_trainer.py b/open_r1/src/mind_openr1/sft_trainer.py deleted file mode 100644 index b0b349787..000000000 --- a/open_r1/src/mind_openr1/sft_trainer.py +++ /dev/null @@ -1,449 +0,0 @@ -""" -Supervised Fine-tuning Trainer for MindSpore/MindNLP -""" -import logging -import os -import sys -from typing import Dict, List, Optional, Union, Any, Callable -from dataclasses import dataclass, field - -import mindspore -from mindspore import nn, ops, Tensor -from mindspore.dataset import GeneratorDataset -import 
mindspore.context as ms_context -import mindspore.communication as comm - -import datasets -from mindnlp.transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - PreTrainedTokenizer, - PreTrainedModel, - TrainingArguments as BaseTrainingArguments -) - -logger = logging.getLogger(__name__) - - -@dataclass -class SFTConfig(BaseTrainingArguments): - """ - Configuration class for SFT training specific parameters. - Inherits from mindnlp TrainingArguments. - """ - max_seq_length: int = field( - default=512, - metadata={"help": "Maximum sequence length for input"} - ) - dataset_text_field: str = field( - default="text", - metadata={"help": "Field name containing text in the dataset"} - ) - packing: bool = field( - default=False, - metadata={"help": "Whether to pack multiple examples in a single sequence"} - ) - dataset_train_split: str = field( - default="train", - metadata={"help": "Name of the training data split"} - ) - dataset_test_split: str = field( - default="test", - metadata={"help": "Name of the test data split"} - ) - - def __post_init__(self): - # Ensure output directory exists - if self.output_dir: - os.makedirs(self.output_dir, exist_ok=True) - - -class SFTTrainer: - """ - Supervised Fine-tuning Trainer for MindSpore/MindNLP - - This trainer handles the training loop for supervised fine-tuning of language models. - """ - - def __init__( - self, - model: Optional[PreTrainedModel] = None, - args: Optional[SFTConfig] = None, - train_dataset: Optional[Union[datasets.Dataset, GeneratorDataset]] = None, - eval_dataset: Optional[Union[datasets.Dataset, GeneratorDataset]] = None, - processing_class: Optional[PreTrainedTokenizer] = None, - peft_config: Optional[Dict] = None, - callbacks: Optional[List[Callable]] = None, - ): - self.model = model - self.args = args or SFTConfig() - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.tokenizer = processing_class - self.peft_config = peft_config - self.callbacks = callbacks or [] - - # Training state - self.global_step = 0 - self.epoch = 0 - self.best_metric = None - self.best_model_checkpoint = None - - # Setup - self._setup_model() - self._setup_optimizer() - self._setup_datasets() - - def _setup_model(self): - """Setup model for training""" - if self.model is None: - raise ValueError("Model must be provided") - - # Set model to training mode - self.model.set_train(True) - - # Apply PEFT config if provided - if self.peft_config: - logger.info("Applying PEFT configuration") - # TODO: Implement PEFT integration - - def _setup_optimizer(self): - """Setup optimizer and learning rate scheduler""" - # Get trainable parameters - trainable_params = self.model.trainable_params() - - # Create optimizer - if self.args.learning_rate is None: - self.args.learning_rate = 5e-5 - - self.optimizer = nn.Adam( - trainable_params, - learning_rate=self.args.learning_rate, - beta1=0.9, - beta2=0.999, - eps=1e-8, - weight_decay=self.args.weight_decay - ) - - def _setup_datasets(self): - """Setup datasets for training""" - if self.train_dataset is None: - raise ValueError("Training dataset must be provided") - - # Process datasets if needed - self.train_dataset = self._prepare_dataset(self.train_dataset, is_train=True) - if self.eval_dataset is not None: - self.eval_dataset = self._prepare_dataset(self.eval_dataset, is_train=False) - - def _prepare_dataset(self, dataset, is_train=True): - """Prepare dataset for training/evaluation""" - # If it's already a GeneratorDataset, return as is - if isinstance(dataset, GeneratorDataset): - 
return dataset - - # Convert HuggingFace dataset to MindSpore dataset - def generator(): - for item in dataset: - yield self._preprocess_function(item) - - column_names = ["input_ids", "attention_mask", "labels"] - - ms_dataset = GeneratorDataset( - generator, - column_names=column_names, - shuffle=is_train - ) - - # Batch the dataset - ms_dataset = ms_dataset.batch( - batch_size=self.args.per_device_train_batch_size if is_train else self.args.per_device_eval_batch_size, - drop_remainder=is_train - ) - - return ms_dataset - - def _preprocess_function(self, examples): - """Preprocess a single example""" - # Get text from the configured field - text = examples.get(self.args.dataset_text_field, "") - - # Tokenize - tokenized = self.tokenizer( - text, - truncation=True, - padding="max_length", - max_length=self.args.max_seq_length, - return_tensors="ms" - ) - - # For causal LM, labels are the same as input_ids - labels = tokenized["input_ids"].copy() - - # Replace padding token id with -100 for loss computation - if self.tokenizer.pad_token_id is not None: - labels[labels == self.tokenizer.pad_token_id] = -100 - - return { - "input_ids": tokenized["input_ids"], - "attention_mask": tokenized["attention_mask"], - "labels": labels - } - - def compute_loss(self, model, inputs): - """Compute loss for a batch of inputs""" - # Forward pass - outputs = model(**inputs) - - # Get loss - if hasattr(outputs, "loss"): - return outputs.loss - else: - # Compute loss manually if model doesn't return it - logits = outputs.logits - labels = inputs.get("labels") - - if labels is not None: - # Shift for causal LM - shift_logits = logits[..., :-1, :].reshape(-1, logits.shape[-1]) - shift_labels = labels[..., 1:].reshape(-1) - - # Compute cross entropy loss - loss_fn = nn.CrossEntropyLoss() - loss = loss_fn(shift_logits, shift_labels) - return loss - - return None - - def training_step(self, batch): - """Perform a single training step""" - # Convert batch to model inputs - inputs = { - "input_ids": batch[0], - "attention_mask": batch[1], - "labels": batch[2] - } - - # Forward pass and compute loss - loss = self.compute_loss(self.model, inputs) - - # Backward pass - grads = ops.grad(self.compute_loss, self.model.trainable_params())(self.model, inputs) - - # Update parameters - self.optimizer(grads) - - return loss - - def train(self, resume_from_checkpoint=None): - """Main training loop""" - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(self.train_dataset)}") - logger.info(f" Num Epochs = {self.args.num_train_epochs}") - logger.info(f" Batch size = {self.args.per_device_train_batch_size}") - logger.info(f" Total optimization steps = {self.args.max_steps}") - - # Resume from checkpoint if provided - if resume_from_checkpoint: - self._load_checkpoint(resume_from_checkpoint) - - # Training loop - for epoch in range(int(self.args.num_train_epochs)): - self.epoch = epoch - epoch_loss = 0.0 - num_batches = 0 - - # Iterate through batches - for step, batch in enumerate(self.train_dataset.create_tuple_iterator()): - loss = self.training_step(batch) - - epoch_loss += loss.asnumpy() - num_batches += 1 - self.global_step += 1 - - # Logging - if self.global_step % self.args.logging_steps == 0: - avg_loss = epoch_loss / num_batches - logger.info(f"Step: {self.global_step}, Loss: {avg_loss:.4f}") - self.log_metrics("train", {"loss": avg_loss}) - - # Save checkpoint - if self.global_step % self.args.save_steps == 0: - self.save_checkpoint() - - # Evaluation - if self.args.eval_strategy != "no" 
and self.global_step % self.args.eval_steps == 0: - self.evaluate() - - # Check max steps - if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: - break - - # End of epoch - avg_epoch_loss = epoch_loss / num_batches - logger.info(f"Epoch {epoch} completed. Average Loss: {avg_epoch_loss:.4f}") - - # Run callbacks - for callback in self.callbacks: - callback(self, epoch=epoch, loss=avg_epoch_loss) - - if self.args.max_steps > 0 and self.global_step >= self.args.max_steps: - break - - # Save final model - self.save_model() - - return {"global_step": self.global_step} - - def evaluate(self): - """Evaluation loop""" - if self.eval_dataset is None: - return {} - - logger.info("***** Running evaluation *****") - self.model.set_train(False) - - total_loss = 0.0 - num_batches = 0 - - for batch in self.eval_dataset.create_tuple_iterator(): - inputs = { - "input_ids": batch[0], - "attention_mask": batch[1], - "labels": batch[2] - } - - loss = self.compute_loss(self.model, inputs) - total_loss += loss.asnumpy() - num_batches += 1 - - avg_loss = total_loss / num_batches - logger.info(f"Evaluation Loss: {avg_loss:.4f}") - - self.model.set_train(True) - - metrics = {"eval_loss": avg_loss} - self.log_metrics("eval", metrics) - - return metrics - - def save_model(self, output_dir=None): - """Save model to disk""" - output_dir = output_dir or self.args.output_dir - os.makedirs(output_dir, exist_ok=True) - - # Save model weights - save_checkpoint(self.model, os.path.join(output_dir, "model.ckpt")) - - # Save tokenizer - if self.tokenizer is not None: - self.tokenizer.save_pretrained(output_dir) - - # Save training arguments - with open(os.path.join(output_dir, "training_args.json"), "w") as f: - import json - json.dump(vars(self.args), f, indent=2) - - logger.info(f"Model saved to {output_dir}") - - def save_checkpoint(self): - """Save training checkpoint""" - checkpoint_dir = os.path.join(self.args.output_dir, f"checkpoint-{self.global_step}") - self.save_model(checkpoint_dir) - - # Save optimizer state - save_checkpoint(self.optimizer, os.path.join(checkpoint_dir, "optimizer.ckpt")) - - # Save training state - state = { - "global_step": self.global_step, - "epoch": self.epoch, - "best_metric": self.best_metric, - } - with open(os.path.join(checkpoint_dir, "trainer_state.json"), "w") as f: - import json - json.dump(state, f, indent=2) - - def _load_checkpoint(self, checkpoint_path): - """Load checkpoint from disk""" - # Load model weights - load_checkpoint(os.path.join(checkpoint_path, "model.ckpt"), self.model) - - # Load optimizer state - if os.path.exists(os.path.join(checkpoint_path, "optimizer.ckpt")): - load_checkpoint(os.path.join(checkpoint_path, "optimizer.ckpt"), self.optimizer) - - # Load training state - state_path = os.path.join(checkpoint_path, "trainer_state.json") - if os.path.exists(state_path): - with open(state_path, "r") as f: - import json - state = json.load(f) - self.global_step = state.get("global_step", 0) - self.epoch = state.get("epoch", 0) - self.best_metric = state.get("best_metric") - - logger.info(f"Resumed from checkpoint: {checkpoint_path}") - - def log_metrics(self, split, metrics): - """Log metrics""" - # Simple console logging - log_str = f"[{split}] Step {self.global_step}: " - log_str += ", ".join([f"{k}={v:.4f}" for k, v in metrics.items()]) - logger.info(log_str) - - def save_metrics(self, split, metrics): - """Save metrics to file""" - metrics_file = os.path.join(self.args.output_dir, f"{split}_metrics.json") - with open(metrics_file, "w") as f: - 
import json - json.dump(metrics, f, indent=2) - - def save_state(self): - """Save trainer state""" - state_file = os.path.join(self.args.output_dir, "trainer_state.json") - state = { - "global_step": self.global_step, - "epoch": self.epoch, - "best_metric": self.best_metric, - } - with open(state_file, "w") as f: - import json - json.dump(state, f, indent=2) - - def create_model_card(self, **kwargs): - """Create model card for the trained model""" - # Simple model card creation - model_card = f""" -# Model Card - -## Model Details -- Model type: Causal Language Model -- Training framework: MindSpore/MindNLP -- Dataset: {kwargs.get('dataset_name', 'Unknown')} - -## Training Details -- Number of epochs: {self.args.num_train_epochs} -- Batch size: {self.args.per_device_train_batch_size} -- Learning rate: {self.args.learning_rate} -- Total steps: {self.global_step} - -## Tags -{kwargs.get('tags', [])} -""" - - with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: - f.write(model_card) - - def push_to_hub(self, **kwargs): - """Push model to model hub (placeholder)""" - logger.warning("push_to_hub is not implemented for MindSpore models yet") - - -def save_checkpoint(model, path): - """Save model checkpoint""" - mindspore.save_checkpoint(model, path) - - -def load_checkpoint(path, model): - """Load model checkpoint""" - mindspore.load_checkpoint(path, model) From 370a07fa35cc927075daf87f01c372a7d1716c3d Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Wed, 8 Oct 2025 13:38:45 +0800 Subject: [PATCH 4/7] add:grpo --- open_r1/sh/grpo.sh | 18 + open_r1/src/mind_openr1/grpo.py | 181 ++++++++ open_r1/src/mind_openr1/rewards.py | 706 +++++++++++++++++++++++++++++ 3 files changed, 905 insertions(+) create mode 100644 open_r1/sh/grpo.sh create mode 100644 open_r1/src/mind_openr1/grpo.py create mode 100644 open_r1/src/mind_openr1/rewards.py diff --git a/open_r1/sh/grpo.sh b/open_r1/sh/grpo.sh new file mode 100644 index 000000000..7891ac5f6 --- /dev/null +++ b/open_r1/sh/grpo.sh @@ -0,0 +1,18 @@ +PYTHONPATH=/home/ma-user/work/mind-openr1/src python src/mind_openr1/grpo.py \ + --model_name_or_path /home/ma-user/work/Qwen2.5-1.5B \ + --dataset_name open-r1/Mixture-of-Thoughts \ + --dataset_config math \ + --eos_token '<|im_end|>' \ + --learning_rate 4.0e-5 \ + --num_train_epochs 5 \ + --per_device_train_batch_size 1 \ + --gradient_checkpointing \ + --bf16 True \ + --torch_dtype bfloat16 \ + --output_dir checkpoints/Qwen2.5-1.5B-GRPO \ + --save_steps 100000 + + # --dataset_name /home/ma-user/work/mind-openr1/data/open-r1___mixture-of-thoughts + +# nohup bash sh/grpo.sh > /home/ma-user/work/mind-openr1/logs/grpo.log 2>&1 & + diff --git a/open_r1/src/mind_openr1/grpo.py b/open_r1/src/mind_openr1/grpo.py new file mode 100644 index 000000000..4dd812231 --- /dev/null +++ b/open_r1/src/mind_openr1/grpo.py @@ -0,0 +1,181 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
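+
+# Editor's note (illustrative; the recipe path and values are hypothetical):
+# besides the flag-style launch shown in sh/grpo.sh above, `TrlParser` also
+# accepts a YAML config, e.g.:
+#
+#     python src/mind_openr1/grpo.py --config recipes/Qwen2.5-1.5B/grpo.yaml
+#
+# where the YAML keys mirror the GRPOScriptArguments / GRPOConfig / ModelConfig
+# fields parsed at the bottom of this file (model_name_or_path, dataset_name,
+# learning_rate, ...).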
+ +import logging +import os +import sys + +import datasets +import transformers +from transformers import set_seed +from transformers.trainer_utils import get_last_checkpoint + +from mind_openr1.configs import GRPOConfig, GRPOScriptArguments +from mind_openr1.rewards import get_reward_funcs +from mind_openr1.utils import get_dataset, get_model, get_tokenizer +from mind_openr1.utils.callbacks import get_callbacks +from mind_openr1.utils.wandb_logging import init_wandb_training +from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config + + +logger = logging.getLogger(__name__) + + +def main(script_args, training_args, model_args): + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Setup logging + ############### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process a small summary + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Model parameters {model_args}") + logger.info(f"Script parameters {script_args}") + logger.info(f"Training parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + if "wandb" in training_args.report_to: + init_wandb_training(training_args) + + # Load the dataset + dataset = get_dataset(script_args) + + ################ + # Load tokenizer + ################ + tokenizer = get_tokenizer(model_args, training_args) + + ############## + # Load model # + ############## + logger.info("*** Loading model ***") + model = get_model(model_args, training_args) + + # Get reward functions from the registry + reward_funcs = get_reward_funcs(script_args) + + # Format into conversation + def make_conversation(example, prompt_column: str = script_args.dataset_prompt_column): + prompt = [] + + if training_args.system_prompt is not None: + prompt.append({"role": "system", "content": training_args.system_prompt}) + + if prompt_column not in example: + raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.") + + prompt.append({"role": "user", "content": example[prompt_column]}) + return {"prompt": prompt} + + dataset = dataset.map(make_conversation) + + for split in dataset: + if "messages" in dataset[split].column_names: + dataset[split] = dataset[split].remove_columns("messages") + + ############################# + # Initialize the GRPO trainer + ############################# + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_funcs, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=(dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None), + peft_config=get_peft_config(model_args), + 
callbacks=get_callbacks(training_args, model_args), + processing_class=tokenizer, + ) + + ############### + # Training loop + ############### + logger.info("*** Train ***") + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + ################################## + # Save model and create model card + ################################## + logger.info("*** Save model ***") + # Align the model's generation config with the tokenizer's eos token + # to avoid unbounded generation in the transformers `pipeline()` function + trainer.model.generation_config.eos_token_id = tokenizer.eos_token_id + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + + # Save everything else on main process + kwargs = { + "dataset_name": script_args.dataset_name, + "tags": ["open-r1"], + } + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) + + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(dataset[script_args.dataset_test_split]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + ############# + # push to hub + ############# + if training_args.push_to_hub: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) + + +if __name__ == "__main__": + parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/open_r1/src/mind_openr1/rewards.py b/open_r1/src/mind_openr1/rewards.py new file mode 100644 index 000000000..0b3662841 --- /dev/null +++ b/open_r1/src/mind_openr1/rewards.py @@ -0,0 +1,706 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
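+
+# Editor's note (illustrative sketch, not part of the upstream file): every
+# reward function in this module follows the calling convention expected by
+# TRL's GRPOTrainer -- it receives `completions` as a list of chat-style
+# message lists plus the remaining dataset columns as keyword arguments, and
+# returns one score (or None to skip the sample) per completion, e.g.:
+#
+#     def my_reward(completions, solution, **kwargs) -> list[float]:
+#         contents = [c[0]["content"] for c in completions]
+#         return [1.0 if sol in content else 0.0 for content, sol in zip(contents, solution)]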
+
+"""Reward functions for GRPO training."""
+
+import asyncio
+import json
+import math
+import re
+from functools import partial, update_wrapper
+from typing import Callable, Dict, Literal, Optional
+
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import LatexExtractionConfig, parse, verify
+
+from .utils.code_providers import get_provider
+from .utils.competitive_programming import (
+    SubtaskResult,
+    add_includes,
+    get_morph_client_from_env,
+    get_piston_client_from_env,
+)
+from .utils.competitive_programming import patch_code as cf_patch_code
+from .utils.competitive_programming import score_submission as cf_score_submission
+from .utils.competitive_programming import score_subtask
+
+
+def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
+    """Reward function that checks if the completion is the same as the ground truth."""
+    contents = [completion[0]["content"] for completion in completions]
+    rewards = []
+    for content, sol in zip(contents, solution):
+        gold_parsed = parse(
+            sol,
+            extraction_mode="first_match",
+        )
+        if len(gold_parsed) != 0:
+            # We require the answer to be provided in correct latex (no malformed operators)
+            answer_parsed = parse(
+                content,
+                extraction_config=[
+                    LatexExtractionConfig(
+                        normalization_config=NormalizationConfig(
+                            nits=False,
+                            malformed_operators=False,
+                            basic_latex=True,
+                            equations=True,
+                            boxed="all",
+                            units=True,
+                        ),
+                        # Ensures that boxed is tried first
+                        boxed_match_priority=0,
+                        try_extract_without_anchor=False,
+                    )
+                ],
+                extraction_mode="first_match",
+            )
+            # Compute binary rewards if verifiable, `None` otherwise to skip this example
+            try:
+                reward = float(verify(gold_parsed, answer_parsed))
+            except Exception as e:
+                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
+                reward = None
+        else:
+            # If the gold solution is not parseable, we assign `None` to skip this example
+            reward = None
+            print("Failed to parse gold solution: ", sol)
+        rewards.append(reward)
+
+    return rewards
+
+
+def format_reward(completions, **kwargs):
+    """Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
+    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
+    completion_contents = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
+    return [1.0 if match else 0.0 for match in matches]
+
+
+def tag_count_reward(completions, **kwargs) -> list[float]:
+    """Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.
+
+    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90
+    """
+
+    def count_tags(text: str) -> float:
+        count = 0.0
+        if text.count("<think>\n") == 1:
+            count += 0.25
+        if text.count("\n</think>\n") == 1:
+            count += 0.25
+        if text.count("\n<answer>\n") == 1:
+            count += 0.25
+        if text.count("\n</answer>") == 1:
+            count += 0.25
+        return count
+
+    contents = [completion[0]["content"] for completion in completions]
+    return [count_tags(c) for c in contents]
+
+
+def reasoning_steps_reward(completions, **kwargs):
+    r"""Reward function that checks for clear step-by-step reasoning.
+    Regex pattern:
+        Step \d+: - matches "Step 1:", "Step 2:", etc.
+        ^\d+\. - matches numbered lists like "1.", "2.", etc.
at start of line + \n- - matches bullet points with hyphens + \n\* - matches bullet points with asterisks + First,|Second,|Next,|Finally, - matches transition words + """ + pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [len(re.findall(pattern, content)) for content in completion_contents] + + # Magic number 3 to encourage 3 steps and more, otherwise partial reward + return [min(1.0, count / 3) for count in matches] + + +def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> float: + """Compute length-based rewards to discourage overthinking and promote token efficiency. + + Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599 + + Args: + completions: List of model completions + solution: List of ground truth solutions + + Returns: + List of rewards where: + - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len) + - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len)) + """ + contents = [completion[0]["content"] for completion in completions] + + # First check correctness of answers + correctness = [] + for content, sol in zip(contents, solution): + gold_parsed = parse( + sol, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(gold_parsed) == 0: + # Skip unparseable examples + correctness.append(True) # Treat as correct to avoid penalizing + print("Failed to parse gold solution: ", sol) + continue + + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + correctness.append(verify(answer_parsed, gold_parsed)) + + # Calculate lengths + lengths = [len(content) for content in contents] + min_len = min(lengths) + max_len = max(lengths) + + # If all responses have the same length, return zero rewards + if max_len == min_len: + return [0.0] * len(completions) + + rewards = [] + for length, is_correct in zip(lengths, correctness): + lambda_val = 0.5 - (length - min_len) / (max_len - min_len) + + if is_correct: + reward = lambda_val + else: + reward = min(0, lambda_val) + + rewards.append(float(reward)) + + return rewards + + +def get_cosine_scaled_reward( + min_value_wrong: float = -1.0, + max_value_wrong: float = -0.5, + min_value_correct: float = 0.5, + max_value_correct: float = 1.0, + max_len: int = 1000, +): + def cosine_scaled_reward(completions, solution, **kwargs): + """Reward function that scales based on completion length using a cosine schedule. + + Shorter correct solutions are rewarded more than longer ones. + Longer incorrect solutions are penalized less than shorter ones. 
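+
+        Concretely (mirroring the code below): reward = min_value + 0.5 * (max_value - min_value)
+        * (1 + cos(pi * gen_len / max_len)), with min_value/max_value swapped for incorrect
+        answers so that longer wrong answers are penalized less harshly.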
+ + Args: + completions: List of model completions + solution: List of ground truth solutions + + This function is parameterized by the following arguments: + min_value_wrong: Minimum reward for wrong answers + max_value_wrong: Maximum reward for wrong answers + min_value_correct: Minimum reward for correct answers + max_value_correct: Maximum reward for correct answers + max_len: Maximum length for scaling + """ + contents = [completion[0]["content"] for completion in completions] + rewards = [] + + for content, sol in zip(contents, solution): + gold_parsed = parse( + sol, + extraction_mode="first_match", + extraction_config=[LatexExtractionConfig()], + ) + if len(gold_parsed) == 0: + rewards.append(1.0) # Skip unparseable examples + print("Failed to parse gold solution: ", sol) + continue + + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + normalization_config=NormalizationConfig( + nits=False, + malformed_operators=False, + basic_latex=True, + equations=True, + boxed=True, + units=True, + ), + boxed_match_priority=0, + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + is_correct = verify(answer_parsed, gold_parsed) + gen_len = len(content) + + # Apply cosine scaling based on length + progress = gen_len / max_len + cosine = math.cos(progress * math.pi) + + if is_correct: + min_value = min_value_correct + max_value = max_value_correct + else: + # Swap min/max for incorrect answers + min_value = max_value_wrong + max_value = min_value_wrong + + reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine) + rewards.append(float(reward)) + + return rewards + + return cosine_scaled_reward + + +def get_repetition_penalty_reward(ngram_size: int, max_penalty: float, language: str = "en"): + """ + Computes N-gram repetition penalty as described in Appendix C.2 of https://huggingface.co/papers/2502.03373. + Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py + + Args: + ngram_size: size of the n-grams + max_penalty: Maximum (negative) penalty for wrong answers + language: Language of the text, defaults to `en`. Used to choose the way to split the text into n-grams. + """ + if max_penalty > 0: + raise ValueError(f"max_penalty {max_penalty} should not be positive") + + if language == "en": + + def zipngram(text: str, ngram_size: int): + words = text.lower().split() + return zip(*[words[i:] for i in range(ngram_size)]), words + + elif language == "zh": + from transformers.utils.import_utils import _is_package_available + + if not _is_package_available("jieba"): + raise ValueError("Please install jieba to use Chinese language") + + def zipngram(text: str, ngram_size: int): + import jieba + + seg_list = list(jieba.cut(text)) + return zip(*[seg_list[i:] for i in range(ngram_size)]), seg_list + + else: + raise ValueError( + f"Word splitting for language `{language}` is not yet implemented. Please implement your own zip-ngram function." 
+ ) + + def repetition_penalty_reward(completions, **kwargs) -> float: + """ + reward function the penalizes repetitions + ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py + + Args: + completions: List of model completions + """ + + contents = [completion[0]["content"] for completion in completions] + rewards = [] + for completion in contents: + if completion == "": + rewards.append(0.0) + continue + + ngrams = set() + total = 0 + ngram_array, words = zipngram(completion, ngram_size) + + if len(words) < ngram_size: + rewards.append(0.0) + continue + + for ng in ngram_array: + ngrams.add(ng) + total += 1 + + scaling = 1 - len(ngrams) / total + reward = scaling * max_penalty + rewards.append(reward) + return rewards + + return repetition_penalty_reward + + +def _init_event_loop(): + """Initialize or get the current event loop.""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +def ioi_code_reward(completions, test_batch_size: int = 1, provider_type: str = "piston", **kwargs) -> list[float]: + """Reward function that evaluates IOI problems using a specified execution client. + + Assumes the dataset has the same format as hf.co/datasets/open-r1/ioi + + Args: + completions: List of model completions to evaluate + test_batch_size: Evaluate these many test cases in parallel, then check if any of them failed (0 score): + if so stop evaluating; otherwise continue with the next batch of test cases. + provider_type: The execution provider to use (default: "piston"). Supported values: "piston", "morph" + **kwargs: Additional arguments passed from the dataset + """ + # Get the appropriate client based on provider_type + if provider_type == "morph": + execution_client = get_morph_client_from_env() + else: + # for info on setting up piston workers, see slurm/piston/README.md + execution_client = get_piston_client_from_env() + + code_snippets = [ + # note: grading is automatically skipped if no code is extracted + add_includes(extract_code(completion[-1]["content"], "cpp"), problem_id) + for completion, problem_id in zip(completions, kwargs["id"]) + ] + + async def run_catch_exceptions(task): + try: + return await task + except Exception as e: + print(f"Error from {provider_type} worker: {e}") + return SubtaskResult() + + problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())] + + loop = _init_event_loop() + evals = [ + loop.create_task( + run_catch_exceptions( + score_subtask( + execution_client, + problem_data, + code, + test_batch_size=test_batch_size, + ) + ) + ) + for problem_data, code in zip(problems_data, code_snippets) + ] + results = loop.run_until_complete(asyncio.gather(*evals)) + + return [result.score for result in results] + + +def cf_code_reward( + completions, + test_batch_size: int = 1, + patch_code: bool = False, + scoring_mode: Literal["pass_fail", "partial", "weighted_sum"] = "weighted_sum", + **kwargs, +) -> list[float]: + """Reward function that evaluates Codeforces problems using Piston+our CF package. + + Assumes the dataset has the same format as hf.co/datasets/open-r1/codeforces (verifiable-prompts subset) + + test_batch_size: evaluate these many test cases in parallel, then check if any of them failed (0 score): if so stop evaluating; otherwise continue with the next batch of test cases. 
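+    patch_code: if True, run the extracted code through the Codeforces code patcher (cf_patch_code) before grading.
+    scoring_mode: how per-test-case results are turned into the returned score; one of "pass_fail", "partial" or "weighted_sum".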
+ """ + # for info on setting up piston workers, see slurm/piston/README.md + piston_client = get_piston_client_from_env() + + languages = kwargs["language"] if "language" in kwargs else [None] * len(completions) + code_snippets = [ + # note: grading is automatically skipped if a problem has no tests + cf_patch_code(extract_code(completion[-1]["content"], language), language) + if patch_code + else extract_code(completion[-1]["content"], language) + for completion, language in zip(completions, languages) + ] + + async def run_catch_exceptions(task): + try: + return await task + except Exception as e: + print(f"Error from Piston worker: {e}") + return None + + # load problem data. undo separating kwargs by column + problems_data = [dict(zip(kwargs.keys(), values)) for values in zip(*kwargs.values())] + + loop = _init_event_loop() + evals = [ + loop.create_task( + run_catch_exceptions( + cf_score_submission( + piston_client, + problem_data, + code, + test_batch_size=test_batch_size, + scoring_mode=scoring_mode, + submission_language=problem_data.get("language", None), + ) + ) + ) + for problem_data, code in zip(problems_data, code_snippets) + ] + results = loop.run_until_complete(asyncio.gather(*evals)) + + return results + + +def extract_code(completion: str, language: str | None = "python") -> str: + if language is None: + return "" + pattern = re.compile(rf"```{language}\n(.*?)```", re.DOTALL) + matches = pattern.findall(completion) + extracted_answer = matches[-1] if len(matches) >= 1 else "" + return extracted_answer + + +def binary_code_reward( + completions, + num_parallel: int = 2, + provider_type: str = "e2b", + enforce_same_language: bool = False, + **kwargs, +) -> list[float]: + rewards = code_reward( + completions, + num_parallel=num_parallel, + provider_type=provider_type, + enforce_same_language=enforce_same_language, + **kwargs, + ) + BINARY_THRESHOLD = 0.99 + + output = [] + for reward in rewards: + if reward is None: + output.append(None) + else: + output.append(1.0 if reward > BINARY_THRESHOLD else 0.0) + + return output + + +def code_reward( + completions, + num_parallel: int = 2, + provider_type: str = "e2b", + enforce_same_language: bool = False, + **kwargs, +) -> list[float]: + """Reward function that evaluates code snippets using a code execution provider. + + Assumes the dataset contains a `verification_info` column with test cases. + + Args: + completions: List of model completions to evaluate + num_parallel: Number of parallel code executions (default: 2) + provider_type: Which code execution provider to use (default: "e2b") + enforce_same_language: If True, verify all problems use the same language (default: False) + **kwargs: Additional arguments passed to the verification + """ + evaluation_script_template = """ + import subprocess + import json + + def evaluate_code(code, test_cases): + passed = 0 + total = len(test_cases) + exec_timeout = 5 + + for case in test_cases: + process = subprocess.run( + ["python3", "-c", code], + input=case["input"], + text=True, + capture_output=True, + timeout=exec_timeout + ) + + if process.returncode != 0: # Error in execution + continue + + output = process.stdout.strip() + + # TODO: implement a proper validator to compare against ground truth. For now we just check for exact string match on each line of stdout. 
+ all_correct = True + for line1, line2 in zip(output.split('\\n'), case['output'].split('\\n')): + all_correct = all_correct and line1.strip() == line2.strip() + + if all_correct: + passed += 1 + + success_rate = (passed / total) + return success_rate + + code_snippet = {code} + test_cases = json.loads({test_cases}) + + evaluate_code(code_snippet, test_cases) + """ + + code_snippets = [extract_code(completion[-1]["content"]) for completion in completions] + verification_info = kwargs["verification_info"] + + template = evaluation_script_template + + scripts = [ + template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))) + for code, info in zip(code_snippets, verification_info) + ] + + language = verification_info[0]["language"] + + if enforce_same_language: + all_same_language = all(v["language"] == language for v in verification_info) + if not all_same_language: + raise ValueError("All verification_info must have the same language", verification_info) + + execution_provider = get_provider( + provider_type=provider_type, + num_parallel=num_parallel, + **kwargs, + ) + + return execution_provider.execute_scripts(scripts, ["python"] * len(scripts)) + + +def get_code_format_reward(language: str = "python"): + """Format reward function specifically for code responses. + + Args: + language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages + """ + + def code_format_reward(completions, **kwargs): + # if there is a language field, use it instead of the default language. This way we can have mixed language training. + languages = kwargs["language"] if "language" in kwargs else [language] * len(completions) + + completion_contents = [completion[0]["content"] for completion in completions] + matches = [ + re.match( + rf"^\n.*?\n\n\n.*?```{sample_language}.*?```.*?\n$", + content, + re.DOTALL | re.MULTILINE, + ) + for content, sample_language in zip(completion_contents, languages) + ] + return [1.0 if match else 0.0 for match in matches] + + return code_format_reward + + +def get_soft_overlong_punishment(max_completion_len, soft_punish_cache): + """ + Reward function that penalizes overlong completions. It is used to penalize overlong completions, + but not to reward shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476) + + Args: + max_completion_len: Maximum length of the completion + soft_punish_cache: Minimum length of the completion. If set to 0, no minimum length is applied. 
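+        The reward is piecewise in the completion length: 0 up to
+        max_completion_len - soft_punish_cache, then a linear ramp down to -1
+        at max_completion_len, and -1 for anything longer.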
+ """ + + def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]: + """Reward function that penalizes overlong completions.""" + rewards = [] + for ids in completion_ids: + completion_length = len(ids) + if completion_length <= max_completion_len - soft_punish_cache: + rewards.append(0.0) + elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len: + rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache) + else: + rewards.append(-1.0) + return rewards + + return soft_overlong_punishment_reward + + +def get_reward_funcs(script_args) -> list[Callable]: + REWARD_FUNCS_REGISTRY = { + "accuracy": accuracy_reward, + "format": format_reward, + "reasoning_steps": reasoning_steps_reward, + "cosine": get_cosine_scaled_reward( + min_value_wrong=script_args.cosine_min_value_wrong, + max_value_wrong=script_args.cosine_max_value_wrong, + min_value_correct=script_args.cosine_min_value_correct, + max_value_correct=script_args.cosine_max_value_correct, + max_len=script_args.cosine_max_len, + ), + "repetition_penalty": get_repetition_penalty_reward( + ngram_size=script_args.repetition_n_grams, + max_penalty=script_args.repetition_max_penalty, + ), + "length": len_reward, + "code": update_wrapper( + partial( + code_reward, + num_parallel=script_args.parallel_code_exec_per_proc, + provider_type=script_args.code_provider, + enforce_same_language=getattr(script_args, "enforce_same_language", False), + ), + code_reward, + ), + "binary_code": update_wrapper( + partial( + binary_code_reward, + num_parallel=script_args.parallel_code_exec_per_proc, + provider_type=script_args.code_provider, + enforce_same_language=getattr(script_args, "enforce_same_language", False), + ), + binary_code_reward, + ), + "ioi_code": update_wrapper( + partial( + ioi_code_reward, + test_batch_size=script_args.code_eval_test_batch_size, + provider_type=getattr(script_args, "ioi_provider", "piston"), + ), + ioi_code_reward, + ), + "cf_code": update_wrapper( + partial( + cf_code_reward, + test_batch_size=script_args.code_eval_test_batch_size, + scoring_mode=script_args.code_eval_scoring_mode, + ), + cf_code_reward, + ), + "code_format": get_code_format_reward(language=script_args.code_language), + "tag_count": tag_count_reward, + "soft_overlong_punishment": get_soft_overlong_punishment( + max_completion_len=script_args.max_completion_len, + soft_punish_cache=script_args.soft_punish_cache, + ), + } + reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs] + + return reward_funcs From 61dfbd364af19fc6047c16a19b8c02276e82096d Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Wed, 8 Oct 2025 14:16:55 +0800 Subject: [PATCH 5/7] add:readme --- open_r1/module.diff | 4771 +++++++++++++++++++++++++++++++++++++++++++ open_r1/readme.md | 310 +++ 2 files changed, 5081 insertions(+) create mode 100644 open_r1/module.diff diff --git a/open_r1/module.diff b/open_r1/module.diff new file mode 100644 index 000000000..309fbfa4e --- /dev/null +++ b/open_r1/module.diff @@ -0,0 +1,4771 @@ +diff --git a/mindtorch/nn/modules/module.py b/mindtorch/nn/modules/module.py +index bf975582..c7fa526d 100644 +--- a/mindtorch/nn/modules/module.py ++++ b/mindtorch/nn/modules/module.py +@@ -1,2373 +1,2393 @@ +-"""Module""" +-import warnings +-import weakref +-import functools +-import inspect +-from typing import Dict, Optional, Callable, Set, overload, TypeVar, Any, Iterator, Tuple, Union, \ +- Mapping, List +-import 
itertools +-from collections import OrderedDict, namedtuple +-import mindspore +-try: +- from mindspore.common._stub_tensor import StubTensor +-except: +- class StubTensor: pass +- +-import mindtorch +-from mindtorch import device, dtype, Tensor +- +-from ..parameter import Parameter, Buffer +-from ...utils import hooks +-from ...utils.hooks import RemovableHandle +- +-_grad_t = Union[Tuple[Tensor, ...], Tensor] +-T = TypeVar('T', bound='Module') +- +-class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): +- def __repr__(self): +- if not self.missing_keys and not self.unexpected_keys: +- return '' +- return super().__repr__() +- +- __str__ = __repr__ +- +-def _addindent(s_, numSpaces): +- s = s_.split('\n') +- # don't do anything for single-line stuff +- if len(s) == 1: +- return s_ +- first = s.pop(0) +- s = [(numSpaces * ' ') + line for line in s] +- s = '\n'.join(s) +- s = first + '\n' + s +- return s +- +-_EXTRA_STATE_KEY_SUFFIX = '_extra_state' +- +-_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() +-_global_module_registration_hooks: Dict[int, Callable] = OrderedDict() +-_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() +- +- +-_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() +-_global_backward_hooks: Dict[int, Callable] = OrderedDict() +-_global_is_full_backward_hook: Optional[bool] = None +-_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() +-_global_forward_hooks: Dict[int, Callable] = OrderedDict() +-_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() +-_global_forward_hooks_with_kwargs: Dict[int, bool] = OrderedDict() +- +- +-class _WrappedHook: +- def __init__(self, hook: Callable, module: Optional["Module"] = None): +- self.hook: Callable = hook +- functools.update_wrapper(self, hook) +- +- self.with_module: bool = False +- +- if module is not None: +- self.module: weakref.ReferenceType[Module] = weakref.ref(module) +- self.with_module = True +- +- def __call__(self, *args: Any, **kwargs: Any) -> Any: +- if self.with_module: +- module = self.module() +- if module is None: +- raise RuntimeError("You are trying to call the hook of a dead Module!") +- return self.hook(module, *args, **kwargs) +- return self.hook(*args, **kwargs) +- +- def __getstate__(self) -> Dict: +- result = {"hook": self.hook, "with_module": self.with_module} +- if self.with_module: +- result["module"] = self.module() +- +- return result +- +- def __setstate__(self, state: Dict): +- self.hook = state["hook"] +- self.with_module = state["with_module"] +- +- if self.with_module: +- if state["module"] is None: +- raise RuntimeError("You are trying to revive the hook of a dead Module!") +- self.module = weakref.ref(state["module"]) +- +- +-def register_module_buffer_registration_hook( +- hook: Callable[..., None], +-) -> RemovableHandle: +- r"""Register a buffer registration hook common to all modules. +- +- .. warning :: +- +- This adds global state to the `nn.Module` module +- +- The hook will be called every time :func:`register_buffer` is invoked. +- It should have the following signature:: +- +- hook(module, name, buffer) -> None or new buffer +- +- The hook can modify the input or return a single modified value in the hook. 
+- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = RemovableHandle(_global_buffer_registration_hooks) +- _global_buffer_registration_hooks[handle.id] = hook +- return handle +- +- +-def register_module_module_registration_hook( +- hook: Callable[..., None], +-) -> RemovableHandle: +- r"""Register a module registration hook common to all modules. +- +- .. warning :: +- +- This adds global state to the `nn.Module` module +- +- The hook will be called every time :func:`register_module` is invoked. +- It should have the following signature:: +- +- hook(module, name, submodule) -> None or new submodule +- +- The hook can modify the input or return a single modified value in the hook. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = RemovableHandle(_global_module_registration_hooks) +- _global_module_registration_hooks[handle.id] = hook +- return handle +- +- +-def register_module_parameter_registration_hook( +- hook: Callable[..., None], +-) -> RemovableHandle: +- r"""Register a parameter registration hook common to all modules. +- +- .. warning :: +- +- This adds global state to the `nn.Module` module +- +- The hook will be called every time :func:`register_parameter` is invoked. +- It should have the following signature:: +- +- hook(module, name, param) -> None or new parameter +- +- The hook can modify the input or return a single modified value in the hook. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = RemovableHandle(_global_parameter_registration_hooks) +- _global_parameter_registration_hooks[handle.id] = hook +- return handle +- +- +-def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: +- r"""Register a forward pre-hook common to all modules. +- +- .. warning :: +- +- This adds global state to the `nn.module` module +- and it is only intended for debugging/profiling purposes. +- +- The hook will be called every time before :func:`forward` is invoked. +- It should have the following signature:: +- +- hook(module, input) -> None or modified input +- +- The input contains only the positional arguments given to the module. +- Keyword arguments won't be passed to the hooks and only to the ``forward``. +- The hook can modify the input. User can either return a tuple or a +- single modified value in the hook. We will wrap the value into a tuple +- if a single value is returned(unless that value is already a tuple). +- +- This hook has precedence over the specific module hooks registered with +- ``register_forward_pre_hook``. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = RemovableHandle(_global_forward_pre_hooks) +- _global_forward_pre_hooks[handle.id] = hook +- return handle +- +- +-def register_module_forward_hook( +- hook: Callable[..., None], +- *, +- with_kwargs: bool = False, +- always_call: bool = False, +-) -> RemovableHandle: +- r"""Register a global forward hook for all the modules. +- +- .. warning :: +- +- This adds global state to the `nn.module` module +- and it is only intended for debugging/profiling purposes. 
+- +- The hook will be called every time after :func:`forward` has computed an output. +- It should have the following signature:: +- +- hook(module, input, output) -> None or modified output +- +- The input contains only the positional arguments given to the module. +- Keyword arguments won't be passed to the hooks and only to the ``forward``. +- You can optionally modify the output of the module by returning a new value +- that will replace the output from the :func:`forward` function. +- +- Parameters: +- hook (Callable): The user defined hook to be registered. +- always_call (bool): If ``True`` the ``hook`` will be run regardless of +- whether an exception is raised while calling the Module. +- Default: ``False`` +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- +- This hook will be executed before specific module hooks registered with +- ``register_forward_hook``. +- """ +- handle = RemovableHandle( +- _global_forward_hooks, extra_dict=_global_forward_hooks_always_called +- ) +- _global_forward_hooks[handle.id] = hook +- if with_kwargs: +- _global_forward_hooks_with_kwargs[handle.id] = True +- if always_call: +- _global_forward_hooks_always_called[handle.id] = True +- return handle +- +- +-def register_module_backward_hook( +- hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +-) -> RemovableHandle: +- r"""Register a backward hook common to all the modules. +- +- This function is deprecated in favor of +- :func:`mindtorch.nn.modules.module.register_module_full_backward_hook` +- and the behavior of this function will change in future versions. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- +- """ +- global _global_is_full_backward_hook +- if _global_is_full_backward_hook is True: +- raise RuntimeError( +- "Cannot use both regular backward hooks and full backward hooks as a " +- "global Module hook. Please use only one of them." +- ) +- +- _global_is_full_backward_hook = False +- +- handle = RemovableHandle(_global_backward_hooks) +- _global_backward_hooks[handle.id] = hook +- return handle +- +- +-def register_module_full_backward_pre_hook( +- hook: Callable[["Module", _grad_t], Union[None, _grad_t]], +-) -> RemovableHandle: +- r"""Register a backward pre-hook common to all the modules. +- +- .. warning :: +- This adds global state to the `nn.module` module +- and it is only intended for debugging/profiling purposes. +- +- Hooks registered using this function behave in the same way as those +- registered by :meth:`mindtorch.nn.Module.register_full_backward_pre_hook`. +- Refer to its documentation for more details. +- +- Hooks registered using this function will be called before hooks registered +- using :meth:`mindtorch.nn.Module.register_full_backward_pre_hook`. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- +- """ +- handle = RemovableHandle(_global_backward_pre_hooks) +- _global_backward_pre_hooks[handle.id] = hook +- return handle +- +- +-def register_module_full_backward_hook( +- hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], +-) -> RemovableHandle: +- r"""Register a backward hook common to all the modules. +- +- .. 
warning :: +- This adds global state to the `nn.module` module +- and it is only intended for debugging/profiling purposes. +- +- Hooks registered using this function behave in the same way as those +- registered by :meth:`mindtorch.nn.Module.register_full_backward_hook`. +- Refer to its documentation for more details. +- +- Hooks registered using this function will be called before hooks registered +- using :meth:`mindtorch.nn.Module.register_full_backward_hook`. +- +- Returns: +- :class:`mindtorch.utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- +- """ +- global _global_is_full_backward_hook +- if _global_is_full_backward_hook is False: +- raise RuntimeError( +- "Cannot use both regular backward hooks and full backward hooks as a " +- "global Module hook. Please use only one of them." +- ) +- +- _global_is_full_backward_hook = True +- +- handle = RemovableHandle(_global_backward_hooks) +- _global_backward_hooks[handle.id] = hook +- return handle +- +- +-# Trick mypy into not applying contravariance rules to inputs by defining +-# forward as a value, rather than a function. See also +-# https://github.com/python/mypy/issues/8795 +-def _forward_unimplemented(self, *input: Any) -> None: +- r"""Define the computation performed at every call. +- +- Should be overridden by all subclasses. +- +- .. note:: +- Although the recipe for forward pass needs to be defined within +- this function, one should call the :class:`Module` instance afterwards +- instead of this since the former takes care of running the +- registered hooks while the latter silently ignores them. +- """ +- raise NotImplementedError( +- f'Module [{type(self).__name__}] is missing the required "forward" function' +- ) +- +-class Module: +- r"""Base class for all neural network modules. +- +- Your models should also subclass this class. +- +- Modules can also contain other Modules, allowing to nest them in +- a tree structure. You can assign the submodules as regular attributes:: +- +- import minispore.nn as nn +- import minispore.nn.functional as F +- +- class Model(nn.Module): +- def __init__(self): +- super(Model, self).__init__() +- self.conv1 = nn.Conv2d(1, 20, 5) +- self.conv2 = nn.Conv2d(20, 20, 5) +- +- def forward(self, x): +- x = F.relu(self.conv1(x)) +- return F.relu(self.conv2(x)) +- """ +- +- __ms_class__ = False +- training: bool +- _parameters: Dict[str, Optional[Parameter]] +- _buffers: Dict[str, Optional[Tensor]] +- _non_persistent_buffers_set: Set[str] +- _backward_pre_hooks: Dict[int, Callable] +- _backward_hooks: Dict[int, Callable] +- _is_full_backward_hook: Optional[bool] +- _forward_hooks: Dict[int, Callable] +- # Marks whether the corresponding _forward_hooks accept kwargs or not. +- # As JIT does not support Set[int], this dict is used as a set, where all +- # hooks represented in this dict accept kwargs. +- _forward_hooks_with_kwargs: Dict[int, bool] +- # forward hooks that should always be called even if an exception is raised +- _forward_hooks_always_called: Dict[int, bool] +- _forward_pre_hooks: Dict[int, Callable] +- # Marks whether the corresponding _forward_hooks accept kwargs or not. +- # As JIT does not support Set[int], this dict is used as a set, where all +- # hooks represented in this dict accept kwargs. 
+- _forward_pre_hooks_with_kwargs: Dict[int, bool] +- _state_dict_hooks: Dict[int, Callable] +- _load_state_dict_pre_hooks: Dict[int, Callable] +- _state_dict_pre_hooks: Dict[int, Callable] +- _load_state_dict_post_hooks: Dict[int, Callable] +- _modules: Dict[str, Optional['Module']] +- call_super_init: bool = False +- _compiled_call_impl : Optional[Callable] = None +- +- def __init__(self): +- """ +- Calls super().__setattr__('a', a) instead of the typical self.a = a +- to avoid Module.__setattr__ overhead. Module's __setattr__ has special +- handling for parameters, submodules, and buffers but simply calls into +- super().__setattr__ for all other attributes. +- """ +- super().__setattr__('training', True) +- super().__setattr__('_parameters', OrderedDict()) +- super().__setattr__('_buffers', OrderedDict()) +- super().__setattr__('_non_persistent_buffers_set', set()) +- super().__setattr__('_backward_pre_hooks', OrderedDict()) +- super().__setattr__('_backward_hooks', OrderedDict()) +- super().__setattr__('_is_full_backward_hook', None) +- super().__setattr__('_forward_hooks', OrderedDict()) +- super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) +- super().__setattr__('_forward_hooks_always_called', OrderedDict()) +- super().__setattr__('_forward_pre_hooks', OrderedDict()) +- super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) +- super().__setattr__('_state_dict_hooks', OrderedDict()) +- super().__setattr__('_state_dict_pre_hooks', OrderedDict()) +- super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) +- super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) +- super().__setattr__('_modules', OrderedDict()) +- +- def forward(self, *input, **kwargs): +- """Defines the computation performed at every call. +- +- Should be overriden by all subclasses. +- +- .. note:: +- Although the recipe for forward pass needs to be defined within +- this function, one should call the :class:`Module` instance afterwards +- instead of this since the former takes care of running the +- registered hooks while the latter silently ignores them. +- """ +- raise NotImplementedError +- +- def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: +- r"""Add a buffer to the module. +- +- This is typically used to register a buffer that should not to be +- considered a model parameter. For example, BatchNorm's ``running_mean`` +- is not a parameter, but is part of the module's state. Buffers, by +- default, are persistent and will be saved alongside parameters. This +- behavior can be changed by setting :attr:`persistent` to ``False``. The +- only difference between a persistent buffer and a non-persistent buffer +- is that the latter will not be a part of this module's +- :attr:`state_dict`. +- +- Buffers can be accessed as attributes using given names. +- +- Args: +- name (str): name of the buffer. The buffer can be accessed +- from this module using the given name +- tensor (Tensor or None): buffer to be registered. If ``None``, then operations +- that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, +- the buffer is **not** included in the module's :attr:`state_dict`. +- persistent (bool): whether the buffer is part of this module's +- :attr:`state_dict`. 
+- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> self.register_buffer('running_mean', ops.zeros(num_features)) +- +- """ +- if '_buffers' not in self.__dict__: +- raise AttributeError( +- "cannot assign buffer before Module.__init__() call") +- elif not isinstance(name, str): +- raise TypeError(f"buffer name should be a string. Got {type(name)}") +- elif '.' in name: +- raise KeyError("buffer name can't contain \".\"") +- elif name == '': +- raise KeyError("buffer name can't be empty string \"\"") +- elif hasattr(self, name) and name not in self._buffers: +- raise KeyError(f"attribute '{name}' already exists") +- elif tensor is not None and not isinstance(tensor, mindtorch.Tensor): +- raise TypeError(f"cannot assign '{type(tensor)}' object to buffer '{name}' " +- "(torch Tensor or None required)" +- ) +- else: +- for hook in _global_buffer_registration_hooks.values(): +- output = hook(self, name, tensor) +- if output is not None: +- tensor = output +- if isinstance(tensor, StubTensor): +- tensor = mindspore.Tensor(tensor.stub_sync()) +- self._buffers[name] = tensor +- if persistent: +- self._non_persistent_buffers_set.discard(name) +- else: +- self._non_persistent_buffers_set.add(name) +- +- def register_parameter(self, name: str, param: Optional[Parameter]) -> None: +- r"""Add a parameter to the module. +- +- The parameter can be accessed as an attribute using given name. +- +- Args: +- name (str): name of the parameter. The parameter can be accessed +- from this module using the given name +- param (Parameter or None): parameter to be added to the module. If +- ``None``, then operations that run on parameters, such as :attr:`cuda`, +- are ignored. If ``None``, the parameter is **not** included in the +- module's :attr:`state_dict`. +- """ +- if '_parameters' not in self.__dict__: +- raise AttributeError( +- "cannot assign parameter before Module.__init__() call") +- +- elif not isinstance(name, str): +- raise TypeError(f"parameter name should be a string. Got {type(name)}") +- elif '.' in name: +- raise KeyError("parameter name can't contain \".\"") +- elif name == '': +- raise KeyError("parameter name can't be empty string \"\"") +- elif hasattr(self, name) and name not in self._parameters: +- raise KeyError(f"attribute '{name}' already exists") +- +- if param is None: +- self._parameters[name] = None +- elif not isinstance(param, Parameter): +- raise TypeError(f"cannot assign '{type(param)}' object to parameter '{name}' " +- "(nn.Parameter or None required)" +- ) +- else: +- for hook in _global_parameter_registration_hooks.values(): +- output = hook(self, name, param) +- if output is not None: +- param = output +- self._parameters[name] = param +- +- def add_module(self, name: str, module: Optional["Module"]) -> None: +- r"""Add a child module to the current module. +- +- The module can be accessed as an attribute using the given name. +- +- Args: +- name (str): name of the child module. The child module can be +- accessed from this module using the given name +- module (Module): child module to be added to the module. +- """ +- if not isinstance(module, Module) and module is not None: +- raise TypeError(f"{mindtorch.typename(module)} is not a Module subclass") +- elif not isinstance(name, str): +- raise TypeError( +- f"module name should be a string. Got {mindtorch.typename(name)}" +- ) +- elif hasattr(self, name) and name not in self._modules: +- raise KeyError(f"attribute '{name}' already exists") +- elif "." 
in name: +- raise KeyError(f'module name can\'t contain ".", got: {name}') +- elif name == "": +- raise KeyError('module name can\'t be empty string ""') +- for hook in _global_module_registration_hooks.values(): +- output = hook(self, name, module) +- if output is not None: +- module = output +- self._modules[name] = module +- +- def register_module(self, name: str, module: Optional["Module"]) -> None: +- r"""Alias for :func:`add_module`.""" +- self.add_module(name, module) +- +- def get_parameter(self, target: str) -> "Parameter": +- """Return the parameter given by ``target`` if it exists, otherwise throw an error. +- +- See the docstring for ``get_submodule`` for a more detailed +- explanation of this method's functionality as well as how to +- correctly specify ``target``. +- +- Args: +- target: The fully-qualified string name of the Parameter +- to look for. (See ``get_submodule`` for how to specify a +- fully-qualified string.) +- +- Returns: +- mindtorch.nn.Parameter: The Parameter referenced by ``target`` +- +- Raises: +- AttributeError: If the target string references an invalid +- path or resolves to something that is not an +- ``nn.Parameter`` +- """ +- module_path, _, param_name = target.rpartition(".") +- +- mod: mindtorch.nn.Module = self.get_submodule(module_path) +- +- if not hasattr(mod, param_name): +- raise AttributeError( +- mod._get_name() + " has no attribute `" + param_name + "`" +- ) +- +- param: mindtorch.nn.Parameter = getattr(mod, param_name) +- +- if not isinstance(param, mindtorch.nn.Parameter): +- raise AttributeError("`" + param_name + "` is not an nn.Parameter") +- +- return param +- +- def get_buffer(self, target: str) -> "Tensor": +- """Return the buffer given by ``target`` if it exists, otherwise throw an error. +- +- See the docstring for ``get_submodule`` for a more detailed +- explanation of this method's functionality as well as how to +- correctly specify ``target``. +- +- Args: +- target: The fully-qualified string name of the buffer +- to look for. (See ``get_submodule`` for how to specify a +- fully-qualified string.) +- +- Returns: +- mindtorch.Tensor: The buffer referenced by ``target`` +- +- Raises: +- AttributeError: If the target string references an invalid +- path or resolves to something that is not a +- buffer +- """ +- module_path, _, buffer_name = target.rpartition(".") +- +- mod: mindtorch.nn.Module = self.get_submodule(module_path) +- +- if not hasattr(mod, buffer_name): +- raise AttributeError( +- mod._get_name() + " has no attribute `" + buffer_name + "`" +- ) +- +- buffer: mindtorch.Tensor = getattr(mod, buffer_name) +- +- if buffer_name not in mod._buffers: +- raise AttributeError("`" + buffer_name + "` is not a buffer") +- +- return buffer +- +- +- def get_extra_state(self) -> Any: +- """Return any extra state to include in the module's state_dict. +- +- Implement this and a corresponding :func:`set_extra_state` for your module +- if you need to store extra state. This function is called when building the +- module's `state_dict()`. +- +- Note that extra state should be picklable to ensure working serialization +- of the state_dict. We only provide provide backwards compatibility guarantees +- for serializing Tensors; other objects may break backwards compatibility if +- their serialized pickled form changes. +- +- Returns: +- object: Any extra state to store in the module's state_dict +- """ +- raise RuntimeError( +- "Reached a code path in Module.get_extra_state() that should never be called. 
" +- "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " +- "to report this bug.") +- +- +- def set_extra_state(self, state: Any) -> None: +- """Set extra state contained in the loaded `state_dict`. +- +- This function is called from :func:`load_state_dict` to handle any extra state +- found within the `state_dict`. Implement this function and a corresponding +- :func:`get_extra_state` for your module if you need to store extra state within its +- `state_dict`. +- +- Args: +- state (dict): Extra state from the `state_dict` +- """ +- raise RuntimeError( +- "Reached a code path in Module.set_extra_state() that should never be called. " +- "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " +- "to report this bug.") +- +- def _apply(self, fn, recurse=True): +- if recurse: +- for module in self.children(): +- module._apply(fn) +- +- def compute_should_use_set_data(tensor, tensor_applied): +- if mindtorch._has_compatible_shallow_copy_type(tensor, tensor_applied): +- # If the new tensor has compatible tensor type as the existing tensor, +- # the current behavior is to change the tensor in-place using `.data =`, +- # and the future behavior is to overwrite the existing tensor. However, +- # changing the current behavior is a BC-breaking change, and we want it +- # to happen in future releases. So for now we introduce the +- # `mindtorch.__future__.get_overwrite_module_params_on_conversion()` +- # global flag to let the user control whether they want the future +- # behavior of overwriting the existing tensor or not. +- return not mindtorch.__future__.get_overwrite_module_params_on_conversion() +- else: +- return False +- +- should_use_swap_tensors = ( +- mindtorch.__future__.get_swap_module_params_on_conversion() +- ) +- +- for key, param in self._parameters.items(): +- if param is None: +- continue +- # Tensors stored in modules are graph leaves, and we don't want to +- # track autograd history of `param_applied`, so we have to use +- # `with mindtorch.no_grad():` +- with mindtorch.no_grad(): +- param_applied = fn(param) +- p_should_use_set_data = compute_should_use_set_data(param, param_applied) +- +- # subclasses may have multiple child tensors so we need to use swap_tensors +- p_should_use_swap_tensors = should_use_swap_tensors +- +- param_grad = param.grad +- if p_should_use_swap_tensors: +- try: +- if param_grad is not None: +- # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. 
+- # Decrement use count of the gradient by setting to None +- param.grad = None +- param_applied = Parameter( +- param_applied, requires_grad=param.requires_grad +- ) +- mindtorch.utils.swap_tensors(param, param_applied) +- except Exception as e: +- if param_grad is not None: +- param.grad = param_grad +- raise RuntimeError( +- f"_apply(): Couldn't swap {self._get_name()}.{key}" +- ) from e +- out_param = param +- elif p_should_use_set_data: +- param.data = param_applied +- out_param = param +- else: +- assert isinstance(param, Parameter) +- assert param.is_leaf +- out_param = Parameter(param_applied, param.requires_grad) +- self._parameters[key] = out_param +- +- if param_grad is not None: +- with mindtorch.no_grad(): +- grad_applied = fn(param_grad) +- g_should_use_set_data = compute_should_use_set_data( +- param_grad, grad_applied +- ) +- if p_should_use_swap_tensors: +- grad_applied.requires_grad_(param_grad.requires_grad) +- try: +- mindtorch.utils.swap_tensors(param_grad, grad_applied) +- except Exception as e: +- raise RuntimeError( +- f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" +- ) from e +- out_param.grad = param_grad +- elif g_should_use_set_data: +- assert out_param.grad is not None +- out_param.grad.data = grad_applied +- else: +- assert param_grad.is_leaf +- out_param.grad = grad_applied.requires_grad_( +- param_grad.requires_grad +- ) +- +- for key, buf in self._buffers.items(): +- if buf is not None: +- self._buffers[key] = fn(buf) +- +- return self +- +- def apply(self, fn): +- """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) +- as well as self. Typical use includes initializing the parameters of a model +- (see also :ref:`torch-nn-init`). +- +- Args: +- fn (:class:`Module` -> None): function to be applied to each submodule +- +- Returns: +- Module: self +- +- Example: +- >>> def init_weights(m): +- >>> print(m) +- >>> if type(m) == nn.Linear: +- >>> m.weight.data.fill_(1.0) +- >>> print(m.weight) +- >>> +- >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) +- >>> net.apply(init_weights) +- Linear (2 -> 2) +- Parameter containing: +- 1 1 +- 1 1 +- [mindtorch.Tensor of size 2x2] +- Linear (2 -> 2) +- Parameter containing: +- 1 1 +- 1 1 +- [mindtorch.Tensor of size 2x2] +- Sequential ( +- (0): Linear (2 -> 2) +- (1): Linear (2 -> 2) +- ) +- """ +- for module in self.children(): +- module.apply(fn) +- fn(self) +- return self +- +- def _wrapped_call_impl(self, *args, **kwargs): +- if self._compiled_call_impl is not None: +- return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] +- return self._call_impl(*args, **kwargs) +- +- # torchrec tests the code consistency with the following code +- # fmt: off +- def _call_impl(self, *args, **kwargs): +- forward_call = self.forward +- # If we don't have any hooks, we want to skip the rest of the logic in +- # this function, and just call forward. 
+- if self.__ms_class__: +- return forward_call(*args, **kwargs) +- +- if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks +- or _global_backward_pre_hooks or _global_backward_hooks +- or _global_forward_hooks or _global_forward_pre_hooks): +- return forward_call(*args, **kwargs) +- +- try: +- result = None +- called_always_called_hooks = set() +- +- full_backward_hooks, non_full_backward_hooks = [], [] +- backward_pre_hooks = [] +- if self._backward_pre_hooks or _global_backward_pre_hooks: +- backward_pre_hooks = self._get_backward_pre_hooks() +- +- if self._backward_hooks or _global_backward_hooks: +- full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() +- +- if _global_forward_pre_hooks or self._forward_pre_hooks: +- for hook_id, hook in ( +- *_global_forward_pre_hooks.items(), +- *self._forward_pre_hooks.items(), +- ): +- if hook_id in self._forward_pre_hooks_with_kwargs: +- args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] +- if args_kwargs_result is not None: +- if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: +- args, kwargs = args_kwargs_result +- else: +- raise RuntimeError( +- "forward pre-hook must return None or a tuple " +- f"of (new_args, new_kwargs), but got {args_kwargs_result}." +- ) +- else: +- args_result = hook(self, args) +- if args_result is not None: +- if not isinstance(args_result, tuple): +- args_result = (args_result,) +- args = args_result +- +- bw_hook = None +- # if full_backward_hooks or backward_pre_hooks: +- # bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) +- # args = bw_hook.setup_input_hook(args) +- +- result = forward_call(*args, **kwargs) +- if _global_forward_hooks or self._forward_hooks: +- for hook_id, hook in ( +- *_global_forward_hooks.items(), +- *self._forward_hooks.items(), +- ): +- # mark that always called hook is run +- if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: +- called_always_called_hooks.add(hook_id) +- +- if hook_id in self._forward_hooks_with_kwargs: +- hook_result = hook(self, args, kwargs, result) +- else: +- hook_result = hook(self, args, result) +- +- if hook_result is not None: +- result = hook_result +- +- if bw_hook: +- if not isinstance(result, (mindtorch.Tensor, tuple)): +- warnings.warn("For backward hooks to be called," +- " module output should be a Tensor or a tuple of Tensors" +- f" but received {type(result)}") +- result = bw_hook.setup_output_hook(result) +- +- # Handle the non-full backward hooks +- if non_full_backward_hooks: +- var = result +- while not isinstance(var, mindtorch.Tensor): +- if isinstance(var, dict): +- var = next(v for v in var.values() if isinstance(v, mindtorch.Tensor)) +- else: +- var = var[0] +- # grad_fn = var.grad_fn +- # if grad_fn is not None: +- # for hook in non_full_backward_hooks: +- # grad_fn.register_hook(_WrappedHook(hook, self)) +- # self._maybe_warn_non_full_backward_hook(args, result, grad_fn) +- +- return result +- +- except Exception: +- # run always called hooks if they have not already been run +- # For now only forward hooks have the always_call option but perhaps +- # this functionality should be added to full backward hooks as well. 
+- for hook_id, hook in _global_forward_hooks.items(): +- if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] +- try: +- hook_result = hook(self, args, result) # type: ignore[possibly-undefined] +- if hook_result is not None: +- result = hook_result +- except Exception as e: +- warnings.warn("global module forward hook with ``always_call=True`` raised an exception " +- f"that was silenced as another error was raised in forward: {str(e)}") +- continue +- +- for hook_id, hook in self._forward_hooks.items(): +- if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] +- try: +- if hook_id in self._forward_hooks_with_kwargs: +- hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] +- else: +- hook_result = hook(self, args, result) # type: ignore[possibly-undefined] +- if hook_result is not None: +- result = hook_result +- except Exception as e: +- warnings.warn("module forward hook with ``always_call=True`` raised an exception " +- f"that was silenced as another error was raised in forward: {str(e)}") +- continue +- # raise exception raised in try block +- raise +- # fmt: on +- +- __call__: Callable[..., Any] = _wrapped_call_impl +- +- def __getstate__(self): +- state = self.__dict__.copy() +- state.pop("_compiled_call_impl", None) +- return state +- +- def __setstate__(self, state): +- self.__dict__.update(state) +- +- # Support loading old checkpoints that don't have the following attrs: +- if "_forward_pre_hooks" not in self.__dict__: +- self._forward_pre_hooks = OrderedDict() +- if "_forward_pre_hooks_with_kwargs" not in self.__dict__: +- self._forward_pre_hooks_with_kwargs = OrderedDict() +- if "_forward_hooks_with_kwargs" not in self.__dict__: +- self._forward_hooks_with_kwargs = OrderedDict() +- if "_forward_hooks_always_called" not in self.__dict__: +- self._forward_hooks_always_called = OrderedDict() +- if "_state_dict_hooks" not in self.__dict__: +- self._state_dict_hooks = OrderedDict() +- if "_state_dict_pre_hooks" not in self.__dict__: +- self._state_dict_pre_hooks = OrderedDict() +- if "_load_state_dict_pre_hooks" not in self.__dict__: +- self._load_state_dict_pre_hooks = OrderedDict() +- if "_load_state_dict_post_hooks" not in self.__dict__: +- self._load_state_dict_post_hooks = OrderedDict() +- if "_non_persistent_buffers_set" not in self.__dict__: +- self._non_persistent_buffers_set = set() +- if "_is_full_backward_hook" not in self.__dict__: +- self._is_full_backward_hook = None +- if "_backward_pre_hooks" not in self.__dict__: +- self._backward_pre_hooks = OrderedDict() +- +- def __getattr__(self, name): +- if '_parameters' in self.__dict__: +- _parameters = self.__dict__['_parameters'] +- if name in _parameters: +- return _parameters[name] +- if '_buffers' in self.__dict__: +- _buffers = self.__dict__['_buffers'] +- if name in _buffers: +- return _buffers[name] +- if '_modules' in self.__dict__: +- modules = self.__dict__['_modules'] +- if name in modules: +- return modules[name] +- raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") +- +- def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: +- def remove_from(*dicts_or_sets): +- for d in dicts_or_sets: +- if name in d: +- if isinstance(d, dict): +- del d[name] +- else: +- d.discard(name) +- +- params = self.__dict__.get("_parameters") +- if isinstance(value, Parameter): +- if params is None: +- 
raise AttributeError( +- "cannot assign parameters before Module.__init__() call" +- ) +- remove_from( +- self.__dict__, +- self._buffers, +- self._modules, +- self._non_persistent_buffers_set, +- ) +- self.register_parameter(name, value) +- elif params is not None and name in params: +- if value is not None: +- raise TypeError( +- f"cannot assign '{mindtorch.typename(value)}' as parameter '{name}' " +- "(mindtorch.nn.Parameter or None expected)" +- ) +- self.register_parameter(name, value) +- else: +- modules = self.__dict__.get("_modules") +- if isinstance(value, Module): +- if modules is None: +- raise AttributeError( +- "cannot assign module before Module.__init__() call" +- ) +- remove_from( +- self.__dict__, +- self._parameters, +- self._buffers, +- self._non_persistent_buffers_set, +- ) +- for hook in _global_module_registration_hooks.values(): +- output = hook(self, name, value) +- if output is not None: +- value = output +- modules[name] = value +- +- elif modules is not None and name in modules: +- if value is not None: +- raise TypeError( +- f"cannot assign '{mindtorch.typename(value)}' as child module '{name}' " +- "(mindtorch.nn.Module or None expected)" +- ) +- for hook in _global_module_registration_hooks.values(): +- output = hook(self, name, value) +- if output is not None: +- value = output +- modules[name] = value +- else: +- buffers = self.__dict__.get("_buffers") +- if isinstance(value, Buffer) or buffers is not None and name in buffers: +- if value is not None and not isinstance(value, mindtorch.Tensor): +- raise TypeError( +- f"cannot assign '{mindtorch.typename(value)}' as buffer '{name}' " +- "(mindtorch.nn.Buffer, mindtorch.Tensor or None expected)" +- ) +- if isinstance(value, Buffer): +- persistent = value.persistent +- else: +- persistent = name not in self._non_persistent_buffers_set +- # === HACK === +- # This whole block below should just be: +- # self.register_buffer(name, value, persistent) +- +- # But to support subclasses of nn.Module that (wrongfully) implement a +- # register_buffer() method that doesn't have the "persistent" +- # argument. Only pass it in if it is accepted otherwise assume +- # it is always true +- if ( +- getattr(self.register_buffer, "__func__", None) +- is Module.register_buffer +- ): +- self.register_buffer(name, value, persistent) +- else: +- sign = inspect.signature(self.register_buffer) +- if "persistent" in sign.parameters: +- self.register_buffer(name, value, persistent) +- else: +- if not persistent: +- raise RuntimeError( +- "Registering a non-persistent buffer " +- "on a Module subclass that implements " +- "register_buffer() without the persistent " +- "argument is not allowed." +- ) +- # Assume that the implementation without the argument has the +- # behavior from before the argument was added: persistent=True +- self.register_buffer(name, value) +- # === HACK END === +- else: +- super().__setattr__(name, value) +- +- def __delattr__(self, name): +- if name in self._parameters: +- del self._parameters[name] +- elif name in self._buffers: +- del self._buffers[name] +- self._non_persistent_buffers_set.discard(name) +- elif name in self._modules: +- del self._modules[name] +- else: +- super().__delattr__(name) +- +- def _register_state_dict_hook(self, hook): +- r"""Register a post-hook for the :meth:`~mindtorch.nn.Module.state_dict` method. 
+- +- It should have the following signature:: +- hook(module, state_dict, prefix, local_metadata) -> None or state_dict +- +- The registered hooks can modify the ``state_dict`` inplace or return a new one. +- If a new ``state_dict`` is returned, it will only be respected if it is the root +- module that :meth:`~nn.Module.state_dict` is called from. +- """ +- if getattr(hook, "_from_public_api", False): +- raise RuntimeError( +- "Cannot register the same function as the state dict post hook that was " +- "previously registered via register_state_dict_post_hook" +- ) +- handle = RemovableHandle(self._state_dict_hooks) +- self._state_dict_hooks[handle.id] = hook +- return handle +- +- def extra_repr(self) -> str: +- r"""Set the extra representation of the module. +- +- To print customized extra information, you should re-implement +- this method in your own modules. Both single-line and multi-line +- strings are acceptable. +- """ +- return '' +- +- +- def __repr__(self): +- # We treat the extra repr like the sub-module, one item per line +- extra_lines = [] +- extra_repr = self.extra_repr() +- # empty string will be split into list [''] +- if extra_repr: +- extra_lines = extra_repr.split('\n') +- child_lines = [] +- for key, module in self._modules.items(): +- mod_str = repr(module) +- mod_str = _addindent(mod_str, 2) +- child_lines.append('(' + key + '): ' + mod_str) +- lines = extra_lines + child_lines +- +- main_str = self._get_name() + '(' +- if lines: +- # simple one-liner info, which most builtin Modules will use +- if len(extra_lines) == 1 and not child_lines: +- main_str += extra_lines[0] +- else: +- main_str += '\n ' + '\n '.join(lines) + '\n' +- +- main_str += ')' +- return main_str +- +- def __dir__(self): +- module_attrs = dir(self.__class__) +- attrs = list(self.__dict__.keys()) +- parameters = list(self._parameters.keys()) +- modules = list(self._modules.keys()) +- buffers = list(self._buffers.keys()) +- keys = module_attrs + attrs + parameters + modules + buffers +- +- # Eliminate attrs that are not legal Python variable names +- keys = [key for key in keys if not key[0].isdigit()] +- +- return sorted(keys) +- +- def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: +- r"""Move all model parameters and buffers to the GPU. +- +- This also makes associated parameters and buffers different objects. So +- it should be called before constructing optimizer if the module will +- live on GPU while being optimized. +- +- .. note:: +- This method modifies the module in-place. +- +- Args: +- device (int, optional): if specified, all parameters will be +- copied to that device +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.cuda(device)) +- +- def npu(self: T, device: Optional[Union[int, device]] = None) -> T: +- return self._apply(lambda t: t.npu(device)) +- +- def cpu(self: T, device: Optional[Union[int, device]] = None) -> T: +- return self._apply(lambda t: t.cpu()) +- +- +- def _load_from_state_dict( +- self, +- state_dict, +- prefix, +- local_metadata, +- strict, +- missing_keys, +- unexpected_keys, +- error_msgs, +- ) -> None: +- r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. +- +- This is called on every submodule +- in :meth:`~mindtorch.nn.Module.load_state_dict`. Metadata saved for this +- module in input :attr:`state_dict` is provided as :attr:`local_metadata`. +- For state dicts without metadata, :attr:`local_metadata` is empty. 
+- Subclasses can achieve class-specific backward compatible loading using +- the version number at `local_metadata.get("version", None)`. +- Additionally, :attr:`local_metadata` can also contain the key +- `assign_to_params_buffers` that indicates whether keys should be +- assigned their corresponding tensor in the state_dict. +- +- .. note:: +- :attr:`state_dict` is not the same object as the input +- :attr:`state_dict` to :meth:`~mindtorch.nn.Module.load_state_dict`. So +- it can be modified. +- +- Args: +- state_dict (dict): a dict containing parameters and +- persistent buffers. +- prefix (str): the prefix for parameters and buffers used in this +- module +- local_metadata (dict): a dict containing the metadata for this module. +- See +- strict (bool): whether to strictly enforce that the keys in +- :attr:`state_dict` with :attr:`prefix` match the names of +- parameters and buffers in this module +- missing_keys (list of str): if ``strict=True``, add missing keys to +- this list +- unexpected_keys (list of str): if ``strict=True``, add unexpected +- keys to this list +- error_msgs (list of str): error messages should be added to this +- list, and will be reported together in +- :meth:`~mindtorch.nn.Module.load_state_dict` +- """ +- for hook in self._load_state_dict_pre_hooks.values(): +- hook( +- state_dict, +- prefix, +- local_metadata, +- strict, +- missing_keys, +- unexpected_keys, +- error_msgs, +- ) +- +- persistent_buffers = { +- k: v +- for k, v in self._buffers.items() +- if k not in self._non_persistent_buffers_set +- } +- local_name_params = itertools.chain( +- self._parameters.items(), persistent_buffers.items() +- ) +- local_state = {k: v for k, v in local_name_params if v is not None} +- assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) +- use_swap_tensors = mindtorch.__future__.get_swap_module_params_on_conversion() +- +- for name, param in local_state.items(): +- key = prefix + name +- if key in state_dict: +- input_param = state_dict[key] +- if not mindtorch.overrides.is_tensor_like(input_param): +- error_msgs.append( +- f'While copying the parameter named "{key}", ' +- "expected mindtorch.Tensor or Tensor-like object from checkpoint but " +- f"received {type(input_param)}" +- ) +- continue +- +- # This is used to avoid copying uninitialized parameters into +- # non-lazy modules, since they dont have the hook to do the checks +- # in such case, it will error when accessing the .shape attribute. +- is_param_lazy = mindtorch.nn.parameter.is_lazy(param) +- # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ +- if ( +- not is_param_lazy +- and len(param.shape) == 0 +- and len(input_param.shape) == 1 +- ): +- input_param = input_param[0] +- +- if not is_param_lazy and input_param.shape != param.shape: +- # local shape should match the one in checkpoint +- error_msgs.append( +- f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " +- f"the shape in current model is {param.shape}." +- ) +- continue +- +- if ( +- param.is_meta +- and not input_param.is_meta +- and not assign_to_params_buffers +- ): +- warnings.warn( +- f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " +- "parameter in the current model, which is a no-op. 
(Did you mean to " +- "pass `assign=True` to assign items in the state dictionary to their " +- "corresponding key in the module instead of copying them in place?)" +- ) +- +- try: +- with mindtorch.no_grad(): +- if use_swap_tensors: +- new_input_param = param.module_load( +- input_param, assign=assign_to_params_buffers +- ) +- if id(new_input_param) == id(input_param) or id( +- new_input_param +- ) == id(param): +- raise RuntimeError( +- "module_load returned one of self or other, please .detach() " +- "the result if returning one of the inputs in module_load" +- ) +- if isinstance(param, mindtorch.nn.Parameter): +- if not isinstance(new_input_param, mindtorch.nn.Parameter): +- new_input_param = mindtorch.nn.Parameter( +- new_input_param, +- requires_grad=param.requires_grad, +- ) +- else: +- new_input_param.requires_grad_(param.requires_grad) +- mindtorch.utils.swap_tensors(param, new_input_param) +- del new_input_param +- elif assign_to_params_buffers: +- # Shape checks are already done above +- if isinstance(param, mindtorch.nn.Parameter): +- if not isinstance(input_param, mindtorch.nn.Parameter): +- input_param = mindtorch.nn.Parameter( +- input_param, requires_grad=param.requires_grad +- ) +- else: +- input_param.requires_grad_(param.requires_grad) +- setattr(self, name, input_param) +- else: +- param.copy_(input_param) +- except Exception as ex: +- action = "swapping" if use_swap_tensors else "copying" +- error_msgs.append( +- f'While {action} the parameter named "{key}", ' +- f"whose dimensions in the model are {param.size()} and " +- f"whose dimensions in the checkpoint are {input_param.size()}, " +- f"an exception occurred : {ex.args}." +- ) +- elif strict: +- missing_keys.append(key) +- +- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX +- if ( +- getattr(self.__class__, "set_extra_state", Module.set_extra_state) +- is not Module.set_extra_state +- ): +- if extra_state_key in state_dict: +- self.set_extra_state(state_dict[extra_state_key]) +- elif strict: +- missing_keys.append(extra_state_key) +- elif strict and (extra_state_key in state_dict): +- unexpected_keys.append(extra_state_key) +- +- if strict: +- for key in state_dict.keys(): +- if key.startswith(prefix) and key != extra_state_key: +- input_name = key[len(prefix) :].split(".", 1) +- # Must be Module if it have attributes +- if len(input_name) > 1: +- if input_name[0] not in self._modules: +- unexpected_keys.append(key) +- elif input_name[0] not in local_state: +- unexpected_keys.append(key) +- +- def load_state_dict(self, state_dict: Mapping[str, Any], +- strict: bool = True, assign: bool = False): +- r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. +- +- If :attr:`strict` is ``True``, then +- the keys of :attr:`state_dict` must exactly match the keys returned +- by this module's :meth:`~nn.Module.state_dict` function. +- +- Args: +- state_dict (dict): a dict containing parameters and +- persistent buffers. +- strict (bool, optional): whether to strictly enforce that the keys +- in :attr:`state_dict` match the keys returned by this module's +- :meth:`~nn.Module.state_dict` function. Default: ``True`` +- assign (bool, optional): When ``False``, the properties of the tensors +- in the current module are preserved while when ``True``, the +- properties of the Tensors in the state dict are preserved. The only +- exception is the ``requires_grad`` field of :class:`~nn.Parameter`s +- for which the value from the module is preserved. 
+- Default: ``False`` +- +- Returns: +- ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: +- * **missing_keys** is a list of str containing the missing keys +- * **unexpected_keys** is a list of str containing the unexpected keys +- +- Note: +- If a parameter or buffer is registered as ``None`` and its corresponding key +- exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a +- ``RuntimeError``. +- """ +- if not isinstance(state_dict, Mapping): +- raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") +- +- missing_keys: List[str] = [] +- unexpected_keys: List[str] = [] +- error_msgs: List[str] = [] +- +- # copy state_dict so _load_from_state_dict can modify it +- metadata = getattr(state_dict, '_metadata', None) +- state_dict = OrderedDict(state_dict) +- +- if metadata is not None: +- # mypy isn't aware that "_metadata" exists in state_dict +- state_dict._metadata = metadata # type: ignore[attr-defined] +- +- def load(module, local_state_dict, prefix=''): +- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) +- if assign: +- local_metadata['assign_to_params_buffers'] = assign +- module._load_from_state_dict( +- local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) +- for name, child in module._modules.items(): +- if child is not None: +- child_prefix = prefix + name + '.' +- child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} +- load(child, child_state_dict, child_prefix) # noqa: F821 +- +- # Note that the hook can modify missing_keys and unexpected_keys. +- incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) +- for hook in module._load_state_dict_post_hooks.values(): +- out = hook(module, incompatible_keys) +- assert out is None, ( +- "Hooks registered with ``register_load_state_dict_post_hook`` are not" +- "expected to return new values, if incompatible_keys need to be modified," +- "it should be done inplace." +- ) +- +- load(self, state_dict) +- del load +- +- if strict: +- if len(unexpected_keys) > 0: +- error_msgs.insert( +- 0, 'Unexpected key(s) in state_dict: {}. '.format( +- ', '.join(f'"{k}"' for k in unexpected_keys))) +- if len(missing_keys) > 0: +- error_msgs.insert( +- 0, 'Missing key(s) in state_dict: {}. '.format( +- ', '.join(f'"{k}"' for k in missing_keys))) +- +- if len(error_msgs) > 0: +- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( +- self.__class__.__name__, "\n\t".join(error_msgs))) +- return _IncompatibleKeys(missing_keys, unexpected_keys) +- +- +- def _named_members( +- self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True +- ): +- r"""Help yield various names + members of modules.""" +- memo = set() +- modules = ( +- self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) +- if recurse +- else [(prefix, self)] +- ) +- for module_prefix, module in modules: +- members = get_members_fn(module) +- for k, v in members: +- if v is None or v in memo: +- continue +- if remove_duplicate: +- memo.add(v) +- name = module_prefix + ("." if module_prefix else "") + k +- yield name, v +- +- def parameters(self, recurse: bool = True) -> Iterator[Parameter]: +- r"""Return an iterator over module parameters. +- +- This is typically passed to an optimizer. +- +- Args: +- recurse (bool): if True, then yields parameters of this module +- and all submodules. Otherwise, yields only parameters that +- are direct members of this module. 
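`load_state_dict` always collects mismatches while recursing and only raises at the end when `strict=True`; with `strict=False` it simply reports them through the returned `missing_keys`/`unexpected_keys` named tuple. A small usage sketch, assuming torch-style `nn.Linear`/`nn.Sequential` constructors as referenced in the docstrings here:

    import mindtorch.nn as nn

    src = nn.Linear(4, 4)
    dst = nn.Sequential(nn.Linear(4, 4))

    # src's keys ("weight", "bias") don't line up with dst's ("0.weight", "0.bias"),
    # so strict loading would raise; non-strict loading reports the difference instead.
    result = dst.load_state_dict(src.state_dict(), strict=False)
    print(result.missing_keys)     # ['0.weight', '0.bias']
    print(result.unexpected_keys)  # ['weight', 'bias']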
+- +- Yields: +- Parameter: module parameter +- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> for param in model.parameters(): +- >>> print(type(param), param.shape) +- (20L,) +- (20L, 1L, 5L, 5L) +- +- """ +- for name, param in self.named_parameters(recurse=recurse): +- yield param +- +- def trainable_params(self, recurse: bool = True): +- params = tuple() +- for name, param in self.named_parameters(recurse=recurse): +- if param.requires_grad: +- params += (param,) +- return params +- +- def get_submodule(self, target: str) -> "Module": +- """Return the submodule given by ``target`` if it exists, otherwise throw an error. +- +- For example, let's say you have an ``nn.Module`` ``A`` that +- looks like this: +- +- .. code-block:: text +- +- A( +- (net_b): Module( +- (net_c): Module( +- (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) +- ) +- (linear): Linear(in_features=100, out_features=200, bias=True) +- ) +- ) +- +- (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested +- submodule ``net_b``, which itself has two submodules ``net_c`` +- and ``linear``. ``net_c`` then has a submodule ``conv``.) +- +- To check whether or not we have the ``linear`` submodule, we +- would call ``get_submodule("net_b.linear")``. To check whether +- we have the ``conv`` submodule, we would call +- ``get_submodule("net_b.net_c.conv")``. +- +- The runtime of ``get_submodule`` is bounded by the degree +- of module nesting in ``target``. A query against +- ``named_modules`` achieves the same result, but it is O(N) in +- the number of transitive modules. So, for a simple check to see +- if some submodule exists, ``get_submodule`` should always be +- used. +- +- Args: +- target: The fully-qualified string name of the submodule +- to look for. (See above example for how to specify a +- fully-qualified string.) +- +- Returns: +- nn.Module: The submodule referenced by ``target`` +- +- Raises: +- AttributeError: If the target string references an invalid +- path or resolves to something that is not an +- ``nn.Module`` +- """ +- if target == "": +- return self +- +- atoms: List[str] = target.split(".") +- mod: Module = self +- +- for item in atoms: +- +- if not hasattr(mod, item): +- raise AttributeError(mod._get_name() + " has no " +- "attribute `" + item + "`") +- +- mod = getattr(mod, item) +- +- if not isinstance(mod, Module): +- raise AttributeError("`" + item + "` is not " +- "an nn.Module") +- +- return mod +- +- def get_parameters(self, expand=True): +- return self.parameters(expand) +- +- def named_parameters( +- self, +- prefix: str = '', +- recurse: bool = True, +- remove_duplicate: bool = True +- ) -> Iterator[Tuple[str, Parameter]]: +- r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. +- +- Args: +- prefix (str): prefix to prepend to all parameter names. +- recurse (bool): if True, then yields parameters of this module +- and all submodules. Otherwise, yields only parameters that +- are direct members of this module. +- remove_duplicate (bool, optional): whether to remove the duplicated +- parameters in the result. Defaults to True. 
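`get_submodule` walks the dotted path one attribute at a time, so the lookup cost is bounded by the nesting depth rather than by the total module count that a `named_modules` scan would traverse. A sketch with an invented two-level container:

    import mindtorch.nn as nn

    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 2)

        def forward(self, x):
            return self.proj(x)

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.block = Block()

        def forward(self, x):
            return self.block(x)

    net = Net()
    print(net.get_submodule("block.proj"))   # the inner Linear
    # net.get_submodule("block.missing")     # would raise AttributeError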
+- +- Yields: +- (str, Parameter): Tuple containing the name and parameter +- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> for name, param in self.named_parameters(): +- >>> if name in ['bias']: +- >>> print(param.shape) +- +- """ +- gen = self._named_members( +- lambda module: module._parameters.items(), +- prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) +- yield from gen +- +- def parameters_and_names(self, name_prefix='', expand=True): +- return self.named_parameters(name_prefix, expand) +- +- def buffers(self, recurse: bool = True) -> Iterator[Tensor]: +- r"""Return an iterator over module buffers. +- +- Args: +- recurse (bool): if True, then yields buffers of this module +- and all submodules. Otherwise, yields only buffers that +- are direct members of this module. +- +- Yields: +- mindtorch.Tensor: module buffer +- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> for buf in model.buffers(): +- >>> print(type(buf), buf.shape) +- (20L,) +- (20L, 1L, 5L, 5L) +- +- """ +- for _, buf in self.named_buffers(recurse=recurse): +- yield buf +- +- +- def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: +- r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. +- +- Args: +- prefix (str): prefix to prepend to all buffer names. +- recurse (bool, optional): if True, then yields buffers of this module +- and all submodules. Otherwise, yields only buffers that +- are direct members of this module. Defaults to True. +- remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. +- +- Yields: +- (str, mindtorch.Tensor): Tuple containing the name and buffer +- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> for name, buf in self.named_buffers(): +- >>> if name in ['running_var']: +- >>> print(buf.shape) +- +- """ +- gen = self._named_members( +- lambda module: module._buffers.items(), +- prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) +- yield from gen +- +- def _all_buffers(self, memo=None): +- if memo is None: +- memo = set() +- for name, b in self._buffers.items(): +- if b is not None and b not in memo: +- memo.add(b) +- yield b +- for module in self.children(): +- for b in module._all_buffers(memo): +- yield b +- +- def children(self): +- """Returns an iterator over immediate children modules. +- +- Yields: +- Module: a child module +- """ +- for name, module in self.named_children(): +- yield module +- +- def named_children(self): +- """Returns an iterator over immediate children modules, yielding both +- the name of the module as well as the module itself. +- +- Yields: +- (string, Module): Tuple containing a name and child module +- +- Example: +- >>> for name, module in model.named_children(): +- >>> if name in ['conv4', 'conv5']: +- >>> print(module) +- """ +- memo = set() +- for name, module in self._modules.items(): +- if module is not None and module not in memo: +- memo.add(module) +- yield name, module +- +- def modules(self): +- """Returns an iterator over all modules in the network. +- +- Yields: +- Module: a module in the network +- +- Note: +- Duplicate modules are returned only once. In the following +- example, ``l`` will be returned only once. 
+- +- >>> l = nn.Linear(2, 2) +- >>> net = nn.Sequential(l, l) +- >>> for idx, m in enumerate(net.modules()): +- >>> print(idx, '->', m) +- 0 -> Sequential ( +- (0): Linear (2 -> 2) +- (1): Linear (2 -> 2) +- ) +- 1 -> Linear (2 -> 2) +- """ +- for name, module in self.named_modules(): +- yield module +- +- def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): +- r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. +- +- Args: +- memo: a memo to store the set of modules already added to the result +- prefix: a prefix that will be added to the name of the module +- remove_duplicate: whether to remove the duplicated module instances in the result +- or not +- +- Yields: +- (str, Module): Tuple of name and module +- +- Note: +- Duplicate modules are returned only once. In the following +- example, ``l`` will be returned only once. +- +- Example:: +- +- >>> l = nn.Linear(2, 2) +- >>> net = nn.Sequential(l, l) +- >>> for idx, m in enumerate(net.named_modules()): +- ... print(idx, '->', m) +- +- 0 -> ('', Sequential( +- (0): Linear(in_features=2, out_features=2, bias=True) +- (1): Linear(in_features=2, out_features=2, bias=True) +- )) +- 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) +- +- """ +- if memo is None: +- memo = set() +- if self not in memo: +- if remove_duplicate: +- memo.add(self) +- yield prefix, self +- for name, module in self._modules.items(): +- if module is None: +- continue +- submodule_prefix = prefix + ('.' if prefix else '') + name +- yield from module.named_modules(memo, submodule_prefix, remove_duplicate) +- +- def jit(self, mode=True): +- self.__ms_class__ = mode +- for module in self.children(): +- module.jit(mode) +- return self +- +- def compile(self, *args, **kwargs): +- self.jit() +- def forward_fn(*args, **kwargs): +- return self.forward(*args, **kwargs) +- +- # forward_fn = mindspore.jit(forward_fn, *args, **kwargs) +- self._compiled_call_impl = forward_fn +- +- @property +- def skip_syntax(self): +- return self.__ms_class__ +- +- def train(self, mode=True): +- """Sets the module in training mode. +- +- This has any effect only on modules such as Dropout or BatchNorm. +- +- Returns: +- Module: self +- """ +- self.training = mode +- for module in self.children(): +- module.train(mode) +- return self +- +- set_train = train +- +- def eval(self): +- """Sets the module in evaluation mode. +- +- This has any effect only on modules such as Dropout or BatchNorm. +- """ +- return self.train(False) +- +- def requires_grad_(self: T, requires_grad: bool = True) -> T: +- r"""Change if autograd should record operations on parameters in this module. +- +- This method sets the parameters' :attr:`requires_grad` attributes +- in-place. +- +- This method is helpful for freezing part of the module for finetuning +- or training parts of a model individually (e.g., GAN training). +- +- See :ref:`locally-disable-grad-doc` for a comparison between +- `.requires_grad_()` and several similar mechanisms that may be confused with it. +- +- Args: +- requires_grad (bool): whether autograd should record operations on +- parameters in this module. Default: ``True``. +- +- Returns: +- Module: self +- """ +- for p in self.parameters(): +- p.requires_grad = requires_grad +- return self +- +- +- def _get_name(self): +- return self.__class__.__name__ +- +- def to(self, *args, **kwargs): +- r"""Move and/or cast the parameters and buffers. 
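`train`/`eval` only toggle the recursive `training` flag, while `requires_grad_` flips the autograd flag on every parameter underneath; combining the two is the usual way to freeze part of a model for finetuning. A sketch, assuming the replacement keeps the `parameters`, `eval` and `requires_grad_` API documented here; the `Classifier` layout is invented:

    import mindtorch.nn as nn

    class Classifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.features = nn.Linear(16, 16)
            self.head = nn.Linear(16, 4)

        def forward(self, x):
            return self.head(self.features(x))

    model = Classifier()
    model.features.requires_grad_(False)   # freeze the feature extractor only
    model.eval()                           # recursive: flips training on every child

    print(model.training)                                     # False
    print(sum(p.requires_grad for p in model.parameters()))   # 2 (head weight + bias)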
+- +- This can be called as +- +- .. function:: to(device=None, dtype=None, non_blocking=False) +- :noindex: +- +- .. function:: to(dtype, non_blocking=False) +- :noindex: +- +- .. function:: to(tensor, non_blocking=False) +- :noindex: +- +- .. function:: to(memory_format=mindtorch.channels_last) +- :noindex: +- +- Its signature is similar to :meth:`mindtorch.Tensor.to`, but only accepts +- floating point or complex :attr:`dtype`\ s. In addition, this method will +- only cast the floating point or complex parameters and buffers to :attr:`dtype` +- (if given). The integral parameters and buffers will be moved +- :attr:`device`, if that is given, but with dtypes unchanged. When +- :attr:`non_blocking` is set, it tries to convert/move asynchronously +- with respect to the host if possible, e.g., moving CPU Tensors with +- pinned memory to CUDA devices. +- +- See below for examples. +- +- .. note:: +- This method modifies the module in-place. +- +- Args: +- device (:class:`mindtorch.device`): the desired device of the parameters +- and buffers in this module +- dtype (:class:`mindtorch.dtype`): the desired floating point or complex dtype of +- the parameters and buffers in this module +- tensor (mindtorch.Tensor): Tensor whose dtype and device are the desired +- dtype and device for all parameters and buffers in this module +- memory_format (:class:`mindtorch.memory_format`): the desired memory +- format for 4D parameters and buffers in this module (keyword +- only argument) +- +- Returns: +- Module: self +- +- Examples:: +- +- >>> # xdoctest: +IGNORE_WANT("non-deterministic") +- >>> linear = nn.Linear(2, 2) +- >>> linear.weight +- Parameter containing: +- tensor([[ 0.1913, -0.3420], +- [-0.5113, -0.2325]]) +- >>> linear.to(mindtorch.double) +- Linear(in_features=2, out_features=2, bias=True) +- >>> linear.weight +- Parameter containing: +- tensor([[ 0.1913, -0.3420], +- [-0.5113, -0.2325]], dtype=mindtorch.float64) +- >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) +- >>> gpu1 = mindtorch.device("cuda:1") +- >>> linear.to(gpu1, dtype=mindtorch.half, non_blocking=True) +- Linear(in_features=2, out_features=2, bias=True) +- >>> linear.weight +- Parameter containing: +- tensor([[ 0.1914, -0.3420], +- [-0.5112, -0.2324]], dtype=mindtorch.float16, device='cuda:1') +- >>> cpu = mindtorch.device("cpu") +- >>> linear.to(cpu) +- Linear(in_features=2, out_features=2, bias=True) +- >>> linear.weight +- Parameter containing: +- tensor([[ 0.1914, -0.3420], +- [-0.5112, -0.2324]], dtype=mindtorch.float16) +- +- >>> linear = nn.Linear(2, 2, bias=None).to(mindtorch.cdouble) +- >>> linear.weight +- Parameter containing: +- tensor([[ 0.3741+0.j, 0.2382+0.j], +- [ 0.5593+0.j, -0.4443+0.j]], dtype=mindtorch.complex128) +- >>> linear(mindtorch.ones(3, 2, dtype=mindtorch.cdouble)) +- tensor([[0.6122+0.j, 0.1150+0.j], +- [0.6122+0.j, 0.1150+0.j], +- [0.6122+0.j, 0.1150+0.j]], dtype=mindtorch.complex128) +- +- """ +- device, dtype, non_blocking, convert_to_format = mindtorch._C._nn._parse_to( +- *args, **kwargs +- ) +- +- if dtype is not None: +- if not (dtype.is_floating_point or dtype.is_complex): +- raise TypeError( +- "nn.Module.to only accepts floating point or complex " +- f"dtypes, but got desired dtype={dtype}" +- ) +- if dtype.is_complex: +- warnings.warn( +- "Complex modules are a new feature under active development whose design may change, " +- "and some modules might not work as expected when using complex tensors as parameters or buffers. 
" +- "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " +- "if a complex module does not work as expected." +- ) +- +- def convert(t): +- try: +- if convert_to_format is not None and t.dim() in (4, 5): +- return t.to( +- device, +- dtype if t.is_floating_point() or t.is_complex() else None, +- non_blocking, +- memory_format=convert_to_format, +- ) +- return t.to( +- device, +- dtype if t.is_floating_point() or t.is_complex() else None, +- non_blocking=non_blocking, +- ) +- except NotImplementedError as e: +- if str(e) == "Cannot copy out of meta tensor; no data!": +- raise NotImplementedError( +- f"{e} Please use mindtorch.nn.Module.to_empty() instead of mindtorch.nn.Module.to() " +- f"when moving module from meta to a different device." +- ) from None +- else: +- raise +- +- return self._apply(convert) +- +- def half(self: T) -> T: +- r"""Casts all floating point parameters and buffers to ``half`` datatype. +- +- .. note:: +- This method modifies the module in-place. +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.half() if t.is_floating_point() else t) +- +- def bfloat16(self: T) -> T: +- r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. +- +- .. note:: +- This method modifies the module in-place. +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) +- +- def to_empty( +- self, *, device, recurse: bool = True +- ): +- r"""Move the parameters and buffers to the specified device without copying storage. +- +- Args: +- device (:class:`mindtorch.device`): The desired device of the parameters +- and buffers in this module. +- recurse (bool): Whether parameters and buffers of submodules should +- be recursively moved to the specified device. +- +- Returns: +- Module: self +- """ +- return self._apply( +- lambda t: mindtorch.empty_like(t, device=device), recurse=recurse +- ) +- +- def float(self: T) -> T: +- r"""Casts all floating point parameters and buffers to ``float`` datatype. +- +- .. note:: +- This method modifies the module in-place. +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.float() if t.is_floating_point() else t) +- +- +- def double(self: T) -> T: +- r"""Casts all floating point parameters and buffers to ``double`` datatype. +- +- .. note:: +- This method modifies the module in-place. +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.double() if t.is_floating_point() else t) +- +- +- def half(self: T) -> T: +- r"""Casts all floating point parameters and buffers to ``half`` datatype. +- +- .. note:: +- This method modifies the module in-place. +- +- Returns: +- Module: self +- """ +- return self._apply(lambda t: t.half() if t.is_floating_point() else t) +- +- +- def _save_to_state_dict(self, destination, prefix, keep_vars): +- r"""Save module state to the `destination` dictionary. +- +- The `destination` dictionary will contain the state +- of the module, but not its descendants. This is called on every +- submodule in :meth:`~nn.Module.state_dict`. +- +- In rare cases, subclasses can achieve class-specific behavior by +- overriding this method with custom logic. 
+- +- Args: +- destination (dict): a dict where state will be stored +- prefix (str): the prefix for parameters and buffers used in this +- module +- """ +- for name, param in self._parameters.items(): +- if param is not None: +- destination[prefix + name] = param +- for name, buf in self._buffers.items(): +- if buf is not None and name not in self._non_persistent_buffers_set: +- destination[prefix + name] = buf +- extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX +- if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: +- destination[extra_state_key] = self.get_extra_state() +- +- # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns +- # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. +- T_destination = TypeVar('T_destination', bound=Dict[str, Any]) +- +- @overload +- def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: +- ... +- +- @overload +- def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: +- ... +- +- def state_dict(self, *args, destination=None, prefix='', keep_vars=False): +- r"""Return a dictionary containing references to the whole state of the module. +- +- Both parameters and persistent buffers (e.g. running averages) are +- included. Keys are corresponding parameter and buffer names. +- Parameters and buffers set to ``None`` are not included. +- +- .. note:: +- The returned object is a shallow copy. It contains references +- to the module's parameters and buffers. +- +- .. warning:: +- Currently ``state_dict()`` also accepts positional arguments for +- ``destination``, ``prefix`` and ``keep_vars`` in order. However, +- this is being deprecated and keyword arguments will be enforced in +- future releases. +- +- .. warning:: +- Please avoid the use of argument ``destination`` as it is not +- designed for end-users. +- +- Args: +- destination (dict, optional): If provided, the state of module will +- be updated into the dict and the same object is returned. +- Otherwise, an ``OrderedDict`` will be created and returned. +- Default: ``None``. +- prefix (str, optional): a prefix added to parameter and buffer +- names to compose the keys in state_dict. Default: ``''``. +- keep_vars (bool, optional): by default the :class:`~mindtorch.Tensor` s +- returned in the state dict are detached from autograd. If it's +- set to ``True``, detaching will not be performed. +- Default: ``False``. +- +- Returns: +- dict: +- a dictionary containing a whole state of the module +- +- Example:: +- +- >>> # xdoctest: +SKIP("undefined vars") +- >>> module.state_dict().keys() +- ['bias', 'weight'] +- +- """ +- # TODO: Remove `args` and the parsing logic when BC allows. 
+- if len(args) > 0: +- if destination is None: +- destination = args[0] +- if len(args) > 1 and prefix == '': +- prefix = args[1] +- if len(args) > 2 and keep_vars is False: +- keep_vars = args[2] +- # DeprecationWarning is ignored by default +- warnings.warn( +- "Positional args are being deprecated, use kwargs instead.") +- +- if destination is None: +- destination = OrderedDict() +- destination._metadata = OrderedDict() +- +- local_metadata = {} +- if hasattr(destination, "_metadata"): +- destination._metadata[prefix[:-1]] = local_metadata +- +- for hook in self._state_dict_pre_hooks.values(): +- hook(self, prefix, keep_vars) +- self._save_to_state_dict(destination, prefix, keep_vars) +- for name, module in self._modules.items(): +- if module is not None: +- module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) +- for hook in self._state_dict_hooks.values(): +- hook_result = hook(self, destination, prefix, local_metadata) +- if hook_result is not None: +- destination = hook_result +- return destination +- +- def _register_load_state_dict_pre_hook(self, hook, with_module=False): +- r"""Register a pre-hook for the :meth:`~nn.Module.load_state_dict` method. +- +- These hooks will be called with arguments: `state_dict`, `prefix`, +- `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, +- `error_msgs`, before loading `state_dict` into `self`. These arguments +- are exactly the same as those of `_load_from_state_dict`. +- +- If ``with_module`` is ``True``, then the first argument to the hook is +- an instance of the module. +- +- Arguments: +- hook (Callable): Callable hook that will be invoked before +- loading the state dict. +- with_module (bool, optional): Whether or not to pass the module +- instance to the hook as the first parameter. +- """ +- handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) +- self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) +- return handle +- +- def register_load_state_dict_post_hook(self, hook): +- r"""Register a post hook to be run after module's ``load_state_dict`` is called. +- +- It should have the following signature:: +- hook(module, incompatible_keys) -> None +- +- The ``module`` argument is the current module that this hook is registered +- on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting +- of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` +- is a ``list`` of ``str`` containing the missing keys and +- ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. +- +- The given incompatible_keys can be modified inplace if needed. +- +- Note that the checks performed when calling :func:`load_state_dict` with +- ``strict=True`` are affected by modifications the hook makes to +- ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either +- set of keys will result in an error being thrown when ``strict=True``, and +- clearing out both missing and unexpected keys will avoid an error. 
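`state_dict()` flattens the whole module tree into an `OrderedDict` of references keyed by dotted parameter/buffer names, which is exactly what `load_state_dict` consumes on the way back in. A round-trip sketch with an invented `Tiny` module:

    import mindtorch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(3, 2)

        def forward(self, x):
            return self.fc(x)

    a, b = Tiny(), Tiny()
    sd = a.state_dict()
    print(list(sd.keys()))     # ['fc.weight', 'fc.bias']

    # Copy a's weights into b; with the default strict=True the keys must match exactly.
    b.load_state_dict(sd)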
+- +- Returns: +- :class:`utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) +- self._load_state_dict_post_hooks[handle.id] = hook +- return handle +- +- def parameters_dict(self, recurse=True): +- param_dict = OrderedDict() +- for name, param in self.named_parameters(recurse=recurse, remove_duplicate=False): +- param_dict[name] = param +- return param_dict +- +- def register_forward_pre_hook( +- self, +- hook: Union[ +- Callable[[T, Tuple[Any, ...]], Optional[Any]], +- Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], +- ], +- *, +- prepend: bool = False, +- with_kwargs: bool = False, +- ) -> RemovableHandle: +- r"""Registers a forward pre-hook on the module. +- +- The hook will be called every time before :func:`forward` is invoked. +- +- +- If ``with_kwargs`` is false or not specified, the input contains only +- the positional arguments given to the module. Keyword arguments won't be +- passed to the hooks and only to the ``forward``. The hook can modify the +- input. User can either return a tuple or a single modified value in the +- hook. We will wrap the value into a tuple if a single value is returned +- (unless that value is already a tuple). The hook should have the +- following signature:: +- +- hook(module, args) -> None or modified input +- +- If ``with_kwargs`` is true, the forward pre-hook will be passed the +- kwargs given to the forward function. And if the hook modifies the +- input, both the args and kwargs should be returned. The hook should have +- the following signature:: +- +- hook(module, args, kwargs) -> None or a tuple of modified input and kwargs +- +- Args: +- hook (Callable): The user defined hook to be registered. +- prepend (bool): If true, the provided ``hook`` will be fired before +- all existing ``forward_pre`` hooks on this +- :class:`nn.modules.Module`. Otherwise, the provided +- ``hook`` will be fired after all existing ``forward_pre`` hooks +- on this :class:`nn.modules.Module`. Note that global +- ``forward_pre`` hooks registered with +- :func:`register_module_forward_pre_hook` will fire before all +- hooks registered by this method. +- Default: ``False`` +- with_kwargs (bool): If true, the ``hook`` will be passed the kwargs +- given to the forward function. +- Default: ``False`` +- +- Returns: +- :class:`utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = hooks.RemovableHandle( +- self._forward_pre_hooks, +- extra_dict=self._forward_pre_hooks_with_kwargs +- ) +- self._forward_pre_hooks[handle.id] = hook +- if with_kwargs: +- self._forward_pre_hooks_with_kwargs[handle.id] = True +- +- if prepend: +- self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] +- return handle +- +- +- def register_forward_hook( +- self, +- hook: Union[ +- Callable[[T, Tuple[Any, ...], Any], Optional[Any]], +- Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], +- ], +- *, +- prepend: bool = False, +- with_kwargs: bool = False, +- ) -> RemovableHandle: +- r"""Registers a forward hook on the module. +- +- The hook will be called every time after :func:`forward` has computed an output. +- +- If ``with_kwargs`` is ``False`` or not specified, the input contains only +- the positional arguments given to the module. 
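A per-module forward pre-hook fires before `forward` and may rewrite the positional arguments by returning a replacement (wrapped into a tuple if a single value is returned); returning `None` leaves the input untouched. A debugging sketch that only logs input shapes, assuming `mindtorch.ones` behaves as in the `to()` examples above:

    import mindtorch
    import mindtorch.nn as nn

    def log_inputs(module, args):
        # args is the tuple of positional arguments passed to forward()
        print(module.__class__.__name__, [tuple(a.shape) for a in args])
        return None   # no modification of the input

    m = nn.Linear(4, 4)
    handle = m.register_forward_pre_hook(log_inputs)
    m(mindtorch.ones(2, 4))
    handle.remove()   # detach the hook when done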
Keyword arguments won't be +- passed to the hooks and only to the ``forward``. The hook can modify the +- output. It can modify the input inplace but it will not have effect on +- forward since this is called after :func:`forward` is called. The hook +- should have the following signature:: +- +- hook(module, args, output) -> None or modified output +- +- If ``with_kwargs`` is ``True``, the forward hook will be passed the +- ``kwargs`` given to the forward function and be expected to return the +- output possibly modified. The hook should have the following signature:: +- +- hook(module, args, kwargs, output) -> None or modified output +- +- Args: +- hook (Callable): The user defined hook to be registered. +- prepend (bool): If ``True``, the provided ``hook`` will be fired +- before all existing ``forward`` hooks on this +- :class:`nn.modules.Module`. Otherwise, the provided +- ``hook`` will be fired after all existing ``forward`` hooks on +- this :class:`nn.modules.Module`. Note that global +- ``forward`` hooks registered with +- :func:`register_module_forward_hook` will fire before all hooks +- registered by this method. +- Default: ``False`` +- with_kwargs (bool): If ``True``, the ``hook`` will be passed the +- kwargs given to the forward function. +- Default: ``False`` +- +- Returns: +- :class:`utils.hooks.RemovableHandle`: +- a handle that can be used to remove the added hook by calling +- ``handle.remove()`` +- """ +- handle = hooks.RemovableHandle( +- self._forward_hooks, +- extra_dict=self._forward_hooks_with_kwargs +- ) +- self._forward_hooks[handle.id] = hook +- if with_kwargs: +- self._forward_hooks_with_kwargs[handle.id] = True +- +- if prepend: +- self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] +- return handle +- +- def zero_grad(self, set_to_none: bool = True) -> None: +- r"""Reset gradients of all model parameters. +- +- See similar function under :class:`mindtorch.optim.Optimizer` for more context. +- +- Args: +- set_to_none (bool): instead of setting to zero, set the grads to None. +- See :meth:`mindtorch.optim.Optimizer.zero_grad` for details. +- """ +- if getattr(self, "_is_replica", False): +- warnings.warn( +- "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " +- "The parameters are copied (in a differentiable manner) from the original module. " +- "This means they are not leaf nodes in autograd and so don't accumulate gradients. " +- "If you need gradients in your forward method, consider using autograd.grad instead." 
+- ) +- +- for p in self.parameters(): +- if p.grad is not None: +- p.grad = None ++"""Module""" ++import warnings ++import weakref ++import functools ++import inspect ++from typing import Dict, Optional, Callable, Set, overload, TypeVar, Any, Iterator, Tuple, Union, \ ++ Mapping, List ++import itertools ++from collections import OrderedDict, namedtuple ++import mindspore ++try: ++ from mindspore.common._stub_tensor import StubTensor ++except: ++ class StubTensor: pass ++ ++import mindtorch ++from mindtorch import device, dtype, Tensor ++from mindspore import ParameterTuple, Parameter as MsParameter ++ ++from ..parameter import Parameter, Buffer ++from ...utils import hooks ++from ...utils.hooks import RemovableHandle ++ ++_grad_t = Union[Tuple[Tensor, ...], Tensor] ++T = TypeVar('T', bound='Module') ++ ++class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])): ++ def __repr__(self): ++ if not self.missing_keys and not self.unexpected_keys: ++ return '' ++ return super().__repr__() ++ ++ __str__ = __repr__ ++ ++def _addindent(s_, numSpaces): ++ s = s_.split('\n') ++ # don't do anything for single-line stuff ++ if len(s) == 1: ++ return s_ ++ first = s.pop(0) ++ s = [(numSpaces * ' ') + line for line in s] ++ s = '\n'.join(s) ++ s = first + '\n' + s ++ return s ++ ++_EXTRA_STATE_KEY_SUFFIX = '_extra_state' ++ ++_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() ++_global_module_registration_hooks: Dict[int, Callable] = OrderedDict() ++_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() ++ ++ ++_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() ++_global_backward_hooks: Dict[int, Callable] = OrderedDict() ++_global_is_full_backward_hook: Optional[bool] = None ++_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() ++_global_forward_hooks: Dict[int, Callable] = OrderedDict() ++_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() ++_global_forward_hooks_with_kwargs: Dict[int, bool] = OrderedDict() ++ ++ ++class _WrappedHook: ++ def __init__(self, hook: Callable, module: Optional["Module"] = None): ++ self.hook: Callable = hook ++ functools.update_wrapper(self, hook) ++ ++ self.with_module: bool = False ++ ++ if module is not None: ++ self.module: weakref.ReferenceType[Module] = weakref.ref(module) ++ self.with_module = True ++ ++ def __call__(self, *args: Any, **kwargs: Any) -> Any: ++ if self.with_module: ++ module = self.module() ++ if module is None: ++ raise RuntimeError("You are trying to call the hook of a dead Module!") ++ return self.hook(module, *args, **kwargs) ++ return self.hook(*args, **kwargs) ++ ++ def __getstate__(self) -> Dict: ++ result = {"hook": self.hook, "with_module": self.with_module} ++ if self.with_module: ++ result["module"] = self.module() ++ ++ return result ++ ++ def __setstate__(self, state: Dict): ++ self.hook = state["hook"] ++ self.with_module = state["with_module"] ++ ++ if self.with_module: ++ if state["module"] is None: ++ raise RuntimeError("You are trying to revive the hook of a dead Module!") ++ self.module = weakref.ref(state["module"]) ++ ++ ++def register_module_buffer_registration_hook( ++ hook: Callable[..., None], ++) -> RemovableHandle: ++ r"""Register a buffer registration hook common to all modules. ++ ++ .. warning :: ++ ++ This adds global state to the `nn.Module` module ++ ++ The hook will be called every time :func:`register_buffer` is invoked. 
++ It should have the following signature:: ++ ++ hook(module, name, buffer) -> None or new buffer ++ ++ The hook can modify the input or return a single modified value in the hook. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = RemovableHandle(_global_buffer_registration_hooks) ++ _global_buffer_registration_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_module_registration_hook( ++ hook: Callable[..., None], ++) -> RemovableHandle: ++ r"""Register a module registration hook common to all modules. ++ ++ .. warning :: ++ ++ This adds global state to the `nn.Module` module ++ ++ The hook will be called every time :func:`register_module` is invoked. ++ It should have the following signature:: ++ ++ hook(module, name, submodule) -> None or new submodule ++ ++ The hook can modify the input or return a single modified value in the hook. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = RemovableHandle(_global_module_registration_hooks) ++ _global_module_registration_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_parameter_registration_hook( ++ hook: Callable[..., None], ++) -> RemovableHandle: ++ r"""Register a parameter registration hook common to all modules. ++ ++ .. warning :: ++ ++ This adds global state to the `nn.Module` module ++ ++ The hook will be called every time :func:`register_parameter` is invoked. ++ It should have the following signature:: ++ ++ hook(module, name, param) -> None or new parameter ++ ++ The hook can modify the input or return a single modified value in the hook. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = RemovableHandle(_global_parameter_registration_hooks) ++ _global_parameter_registration_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: ++ r"""Register a forward pre-hook common to all modules. ++ ++ .. warning :: ++ ++ This adds global state to the `nn.module` module ++ and it is only intended for debugging/profiling purposes. ++ ++ The hook will be called every time before :func:`forward` is invoked. ++ It should have the following signature:: ++ ++ hook(module, input) -> None or modified input ++ ++ The input contains only the positional arguments given to the module. ++ Keyword arguments won't be passed to the hooks and only to the ``forward``. ++ The hook can modify the input. User can either return a tuple or a ++ single modified value in the hook. We will wrap the value into a tuple ++ if a single value is returned(unless that value is already a tuple). ++ ++ This hook has precedence over the specific module hooks registered with ++ ``register_forward_pre_hook``. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = RemovableHandle(_global_forward_pre_hooks) ++ _global_forward_pre_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_forward_hook( ++ hook: Callable[..., None], ++ *, ++ with_kwargs: bool = False, ++ always_call: bool = False, ++) -> RemovableHandle: ++ r"""Register a global forward hook for all the modules. 
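The module-level `register_module_*` functions install hooks globally, for every `Module` instance, which is why the docstrings flag them as debugging/profiling-only global state; the returned handle is the only way to undo the registration. A sketch of a global forward pre-hook used as a lightweight call tracer (import path taken from the cross-references in these docstrings):

    import mindtorch
    import mindtorch.nn as nn
    from mindtorch.nn.modules.module import register_module_forward_pre_hook

    calls = []

    def trace(module, args):
        calls.append(type(module).__name__)
        return None

    handle = register_module_forward_pre_hook(trace)
    try:
        nn.Linear(2, 2)(mindtorch.ones(1, 2))
    finally:
        handle.remove()   # global hooks affect every module, so always clean up

    print(calls)          # ['Linear']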
++ ++ .. warning :: ++ ++ This adds global state to the `nn.module` module ++ and it is only intended for debugging/profiling purposes. ++ ++ The hook will be called every time after :func:`forward` has computed an output. ++ It should have the following signature:: ++ ++ hook(module, input, output) -> None or modified output ++ ++ The input contains only the positional arguments given to the module. ++ Keyword arguments won't be passed to the hooks and only to the ``forward``. ++ You can optionally modify the output of the module by returning a new value ++ that will replace the output from the :func:`forward` function. ++ ++ Parameters: ++ hook (Callable): The user defined hook to be registered. ++ always_call (bool): If ``True`` the ``hook`` will be run regardless of ++ whether an exception is raised while calling the Module. ++ Default: ``False`` ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ ++ This hook will be executed before specific module hooks registered with ++ ``register_forward_hook``. ++ """ ++ handle = RemovableHandle( ++ _global_forward_hooks, extra_dict=_global_forward_hooks_always_called ++ ) ++ _global_forward_hooks[handle.id] = hook ++ if with_kwargs: ++ _global_forward_hooks_with_kwargs[handle.id] = True ++ if always_call: ++ _global_forward_hooks_always_called[handle.id] = True ++ return handle ++ ++ ++def register_module_backward_hook( ++ hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], ++) -> RemovableHandle: ++ r"""Register a backward hook common to all the modules. ++ ++ This function is deprecated in favor of ++ :func:`mindtorch.nn.modules.module.register_module_full_backward_hook` ++ and the behavior of this function will change in future versions. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ ++ """ ++ global _global_is_full_backward_hook ++ if _global_is_full_backward_hook is True: ++ raise RuntimeError( ++ "Cannot use both regular backward hooks and full backward hooks as a " ++ "global Module hook. Please use only one of them." ++ ) ++ ++ _global_is_full_backward_hook = False ++ ++ handle = RemovableHandle(_global_backward_hooks) ++ _global_backward_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_full_backward_pre_hook( ++ hook: Callable[["Module", _grad_t], Union[None, _grad_t]], ++) -> RemovableHandle: ++ r"""Register a backward pre-hook common to all the modules. ++ ++ .. warning :: ++ This adds global state to the `nn.module` module ++ and it is only intended for debugging/profiling purposes. ++ ++ Hooks registered using this function behave in the same way as those ++ registered by :meth:`mindtorch.nn.Module.register_full_backward_pre_hook`. ++ Refer to its documentation for more details. ++ ++ Hooks registered using this function will be called before hooks registered ++ using :meth:`mindtorch.nn.Module.register_full_backward_pre_hook`. 
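`register_module_forward_hook` runs after every `forward`; with `always_call=True` the hook id is also stored in the separate `_global_forward_hooks_always_called` table so it is meant to fire even when the wrapped call raises. A profiling-style sketch that records output shapes (the helper names are invented):

    import mindtorch
    import mindtorch.nn as nn
    from mindtorch.nn.modules.module import register_module_forward_hook

    shapes = []

    def record(module, args, output):
        shapes.append((type(module).__name__, tuple(output.shape)))
        return None   # returning a value would replace the module's output

    handle = register_module_forward_hook(record, always_call=True)
    try:
        nn.Linear(3, 5)(mindtorch.ones(2, 3))
    finally:
        handle.remove()

    print(shapes)   # [('Linear', (2, 5))]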
++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ ++ """ ++ handle = RemovableHandle(_global_backward_pre_hooks) ++ _global_backward_pre_hooks[handle.id] = hook ++ return handle ++ ++ ++def register_module_full_backward_hook( ++ hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], ++) -> RemovableHandle: ++ r"""Register a backward hook common to all the modules. ++ ++ .. warning :: ++ This adds global state to the `nn.module` module ++ and it is only intended for debugging/profiling purposes. ++ ++ Hooks registered using this function behave in the same way as those ++ registered by :meth:`mindtorch.nn.Module.register_full_backward_hook`. ++ Refer to its documentation for more details. ++ ++ Hooks registered using this function will be called before hooks registered ++ using :meth:`mindtorch.nn.Module.register_full_backward_hook`. ++ ++ Returns: ++ :class:`mindtorch.utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ ++ """ ++ global _global_is_full_backward_hook ++ if _global_is_full_backward_hook is False: ++ raise RuntimeError( ++ "Cannot use both regular backward hooks and full backward hooks as a " ++ "global Module hook. Please use only one of them." ++ ) ++ ++ _global_is_full_backward_hook = True ++ ++ handle = RemovableHandle(_global_backward_hooks) ++ _global_backward_hooks[handle.id] = hook ++ return handle ++ ++ ++# Trick mypy into not applying contravariance rules to inputs by defining ++# forward as a value, rather than a function. See also ++# https://github.com/python/mypy/issues/8795 ++def _forward_unimplemented(self, *input: Any) -> None: ++ r"""Define the computation performed at every call. ++ ++ Should be overridden by all subclasses. ++ ++ .. note:: ++ Although the recipe for forward pass needs to be defined within ++ this function, one should call the :class:`Module` instance afterwards ++ instead of this since the former takes care of running the ++ registered hooks while the latter silently ignores them. ++ """ ++ raise NotImplementedError( ++ f'Module [{type(self).__name__}] is missing the required "forward" function' ++ ) ++ ++class Module: ++ r"""Base class for all neural network modules. ++ ++ Your models should also subclass this class. ++ ++ Modules can also contain other Modules, allowing to nest them in ++ a tree structure. You can assign the submodules as regular attributes:: ++ ++ import minispore.nn as nn ++ import minispore.nn.functional as F ++ ++ class Model(nn.Module): ++ def __init__(self): ++ super(Model, self).__init__() ++ self.conv1 = nn.Conv2d(1, 20, 5) ++ self.conv2 = nn.Conv2d(20, 20, 5) ++ ++ def forward(self, x): ++ x = F.relu(self.conv1(x)) ++ return F.relu(self.conv2(x)) ++ """ ++ ++ __ms_class__ = False ++ training: bool ++ _parameters: Dict[str, Optional[Parameter]] ++ _buffers: Dict[str, Optional[Tensor]] ++ _non_persistent_buffers_set: Set[str] ++ _backward_pre_hooks: Dict[int, Callable] ++ _backward_hooks: Dict[int, Callable] ++ _is_full_backward_hook: Optional[bool] ++ _forward_hooks: Dict[int, Callable] ++ # Marks whether the corresponding _forward_hooks accept kwargs or not. ++ # As JIT does not support Set[int], this dict is used as a set, where all ++ # hooks represented in this dict accept kwargs. 
++ _forward_hooks_with_kwargs: Dict[int, bool] ++ # forward hooks that should always be called even if an exception is raised ++ _forward_hooks_always_called: Dict[int, bool] ++ _forward_pre_hooks: Dict[int, Callable] ++ # Marks whether the corresponding _forward_hooks accept kwargs or not. ++ # As JIT does not support Set[int], this dict is used as a set, where all ++ # hooks represented in this dict accept kwargs. ++ _forward_pre_hooks_with_kwargs: Dict[int, bool] ++ _state_dict_hooks: Dict[int, Callable] ++ _load_state_dict_pre_hooks: Dict[int, Callable] ++ _state_dict_pre_hooks: Dict[int, Callable] ++ _load_state_dict_post_hooks: Dict[int, Callable] ++ _modules: Dict[str, Optional['Module']] ++ call_super_init: bool = False ++ _compiled_call_impl : Optional[Callable] = None ++ ++ def __init__(self): ++ """ ++ Calls super().__setattr__('a', a) instead of the typical self.a = a ++ to avoid Module.__setattr__ overhead. Module's __setattr__ has special ++ handling for parameters, submodules, and buffers but simply calls into ++ super().__setattr__ for all other attributes. ++ """ ++ super().__setattr__('training', True) ++ super().__setattr__('_parameters', OrderedDict()) ++ super().__setattr__('_buffers', OrderedDict()) ++ super().__setattr__('_non_persistent_buffers_set', set()) ++ super().__setattr__('_backward_pre_hooks', OrderedDict()) ++ super().__setattr__('_backward_hooks', OrderedDict()) ++ super().__setattr__('_is_full_backward_hook', None) ++ super().__setattr__('_forward_hooks', OrderedDict()) ++ super().__setattr__('_forward_hooks_with_kwargs', OrderedDict()) ++ super().__setattr__('_forward_hooks_always_called', OrderedDict()) ++ super().__setattr__('_forward_pre_hooks', OrderedDict()) ++ super().__setattr__('_forward_pre_hooks_with_kwargs', OrderedDict()) ++ super().__setattr__('_state_dict_hooks', OrderedDict()) ++ super().__setattr__('_state_dict_pre_hooks', OrderedDict()) ++ super().__setattr__('_load_state_dict_pre_hooks', OrderedDict()) ++ super().__setattr__('_load_state_dict_post_hooks', OrderedDict()) ++ super().__setattr__('_modules', OrderedDict()) ++ ++ def forward(self, *input, **kwargs): ++ """Defines the computation performed at every call. ++ ++ Should be overriden by all subclasses. ++ ++ .. note:: ++ Although the recipe for forward pass needs to be defined within ++ this function, one should call the :class:`Module` instance afterwards ++ instead of this since the former takes care of running the ++ registered hooks while the latter silently ignores them. ++ """ ++ raise NotImplementedError ++ ++ def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None: ++ r"""Add a buffer to the module. ++ ++ This is typically used to register a buffer that should not to be ++ considered a model parameter. For example, BatchNorm's ``running_mean`` ++ is not a parameter, but is part of the module's state. Buffers, by ++ default, are persistent and will be saved alongside parameters. This ++ behavior can be changed by setting :attr:`persistent` to ``False``. The ++ only difference between a persistent buffer and a non-persistent buffer ++ is that the latter will not be a part of this module's ++ :attr:`state_dict`. ++ ++ Buffers can be accessed as attributes using given names. ++ ++ Args: ++ name (str): name of the buffer. The buffer can be accessed ++ from this module using the given name ++ tensor (Tensor or None): buffer to be registered. If ``None``, then operations ++ that run on buffers, such as :attr:`cuda`, are ignored. 
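Because parameter, buffer and submodule registration all go through the internal dicts that `Module.__init__` creates via `super().__setattr__`, skipping `super().__init__()` in a subclass fails immediately with the "cannot assign ... before Module.__init__() call" errors raised in `register_buffer`/`register_parameter`. Minimal sketch of the required pattern:

    import mindtorch.nn as nn

    class Good(nn.Module):
        def __init__(self):
            super().__init__()          # sets up _parameters/_buffers/_modules first
            self.fc = nn.Linear(2, 2)   # now safe to attach submodules

        def forward(self, x):
            return self.fc(x)

    class Bad(nn.Module):
        def __init__(self):
            # no super().__init__(): '_parameters' does not exist yet, so this raises
            # AttributeError: cannot assign parameter before Module.__init__() call
            self.register_parameter("w", None)

    Good()
    # Bad()   # raises as described above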
If ``None``, ++ the buffer is **not** included in the module's :attr:`state_dict`. ++ persistent (bool): whether the buffer is part of this module's ++ :attr:`state_dict`. ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> self.register_buffer('running_mean', ops.zeros(num_features)) ++ ++ """ ++ if '_buffers' not in self.__dict__: ++ raise AttributeError( ++ "cannot assign buffer before Module.__init__() call") ++ elif not isinstance(name, str): ++ raise TypeError(f"buffer name should be a string. Got {type(name)}") ++ elif '.' in name: ++ raise KeyError("buffer name can't contain \".\"") ++ elif name == '': ++ raise KeyError("buffer name can't be empty string \"\"") ++ elif hasattr(self, name) and name not in self._buffers: ++ raise KeyError(f"attribute '{name}' already exists") ++ elif tensor is not None and not isinstance(tensor, mindtorch.Tensor): ++ raise TypeError(f"cannot assign '{type(tensor)}' object to buffer '{name}' " ++ "(torch Tensor or None required)" ++ ) ++ else: ++ for hook in _global_buffer_registration_hooks.values(): ++ output = hook(self, name, tensor) ++ if output is not None: ++ tensor = output ++ if isinstance(tensor, StubTensor): ++ tensor = mindspore.Tensor(tensor.stub_sync()) ++ self._buffers[name] = tensor ++ if persistent: ++ self._non_persistent_buffers_set.discard(name) ++ else: ++ self._non_persistent_buffers_set.add(name) ++ ++ def register_parameter(self, name: str, param: Optional[Parameter]) -> None: ++ r"""Add a parameter to the module. ++ ++ The parameter can be accessed as an attribute using given name. ++ ++ Args: ++ name (str): name of the parameter. The parameter can be accessed ++ from this module using the given name ++ param (Parameter or None): parameter to be added to the module. If ++ ``None``, then operations that run on parameters, such as :attr:`cuda`, ++ are ignored. If ``None``, the parameter is **not** included in the ++ module's :attr:`state_dict`. ++ """ ++ if '_parameters' not in self.__dict__: ++ raise AttributeError( ++ "cannot assign parameter before Module.__init__() call") ++ ++ elif not isinstance(name, str): ++ raise TypeError(f"parameter name should be a string. Got {type(name)}") ++ elif '.' in name: ++ raise KeyError("parameter name can't contain \".\"") ++ elif name == '': ++ raise KeyError("parameter name can't be empty string \"\"") ++ elif hasattr(self, name) and name not in self._parameters: ++ raise KeyError(f"attribute '{name}' already exists") ++ ++ if param is None: ++ self._parameters[name] = None ++ elif not isinstance(param, Parameter): ++ raise TypeError(f"cannot assign '{type(param)}' object to parameter '{name}' " ++ "(nn.Parameter or None required)" ++ ) ++ else: ++ for hook in _global_parameter_registration_hooks.values(): ++ output = hook(self, name, param) ++ if output is not None: ++ param = output ++ self._parameters[name] = param ++ ++ def add_module(self, name: str, module: Optional["Module"]) -> None: ++ r"""Add a child module to the current module. ++ ++ The module can be accessed as an attribute using the given name. ++ ++ Args: ++ name (str): name of the child module. The child module can be ++ accessed from this module using the given name ++ module (Module): child module to be added to the module. ++ """ ++ if not isinstance(module, Module) and module is not None: ++ raise TypeError(f"{mindtorch.typename(module)} is not a Module subclass") ++ elif not isinstance(name, str): ++ raise TypeError( ++ f"module name should be a string. 
Got {mindtorch.typename(name)}" ++ ) ++ elif hasattr(self, name) and name not in self._modules: ++ raise KeyError(f"attribute '{name}' already exists") ++ elif "." in name: ++ raise KeyError(f'module name can\'t contain ".", got: {name}') ++ elif name == "": ++ raise KeyError('module name can\'t be empty string ""') ++ for hook in _global_module_registration_hooks.values(): ++ output = hook(self, name, module) ++ if output is not None: ++ module = output ++ self._modules[name] = module ++ ++ def register_module(self, name: str, module: Optional["Module"]) -> None: ++ r"""Alias for :func:`add_module`.""" ++ self.add_module(name, module) ++ ++ def get_parameter(self, target: str) -> "Parameter": ++ """Return the parameter given by ``target`` if it exists, otherwise throw an error. ++ ++ See the docstring for ``get_submodule`` for a more detailed ++ explanation of this method's functionality as well as how to ++ correctly specify ``target``. ++ ++ Args: ++ target: The fully-qualified string name of the Parameter ++ to look for. (See ``get_submodule`` for how to specify a ++ fully-qualified string.) ++ ++ Returns: ++ mindtorch.nn.Parameter: The Parameter referenced by ``target`` ++ ++ Raises: ++ AttributeError: If the target string references an invalid ++ path or resolves to something that is not an ++ ``nn.Parameter`` ++ """ ++ module_path, _, param_name = target.rpartition(".") ++ ++ mod: mindtorch.nn.Module = self.get_submodule(module_path) ++ ++ if not hasattr(mod, param_name): ++ raise AttributeError( ++ mod._get_name() + " has no attribute `" + param_name + "`" ++ ) ++ ++ param: mindtorch.nn.Parameter = getattr(mod, param_name) ++ ++ if not isinstance(param, mindtorch.nn.Parameter): ++ raise AttributeError("`" + param_name + "` is not an nn.Parameter") ++ ++ return param ++ ++ def get_buffer(self, target: str) -> "Tensor": ++ """Return the buffer given by ``target`` if it exists, otherwise throw an error. ++ ++ See the docstring for ``get_submodule`` for a more detailed ++ explanation of this method's functionality as well as how to ++ correctly specify ``target``. ++ ++ Args: ++ target: The fully-qualified string name of the buffer ++ to look for. (See ``get_submodule`` for how to specify a ++ fully-qualified string.) ++ ++ Returns: ++ mindtorch.Tensor: The buffer referenced by ``target`` ++ ++ Raises: ++ AttributeError: If the target string references an invalid ++ path or resolves to something that is not a ++ buffer ++ """ ++ module_path, _, buffer_name = target.rpartition(".") ++ ++ mod: mindtorch.nn.Module = self.get_submodule(module_path) ++ ++ if not hasattr(mod, buffer_name): ++ raise AttributeError( ++ mod._get_name() + " has no attribute `" + buffer_name + "`" ++ ) ++ ++ buffer: mindtorch.Tensor = getattr(mod, buffer_name) ++ ++ if buffer_name not in mod._buffers: ++ raise AttributeError("`" + buffer_name + "` is not a buffer") ++ ++ return buffer ++ ++ ++ def get_extra_state(self) -> Any: ++ """Return any extra state to include in the module's state_dict. ++ ++ Implement this and a corresponding :func:`set_extra_state` for your module ++ if you need to store extra state. This function is called when building the ++ module's `state_dict()`. ++ ++ Note that extra state should be picklable to ensure working serialization ++ of the state_dict. We only provide provide backwards compatibility guarantees ++ for serializing Tensors; other objects may break backwards compatibility if ++ their serialized pickled form changes. 
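`register_buffer` and `register_parameter` are the two entry points that make state visible to `state_dict`, `_apply` and friends; a non-persistent buffer stays usable as an attribute but is skipped at serialization time, and `get_parameter`/`get_buffer` resolve dotted paths back to those objects. A sketch with invented names, assuming `mindtorch.ones` and `nn.Parameter` work as used elsewhere in this file:

    import mindtorch
    import mindtorch.nn as nn

    class Running(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_parameter("scale", nn.Parameter(mindtorch.ones(4)))
            self.register_buffer("running_mean", mindtorch.ones(4))            # persistent: saved
            self.register_buffer("tmp", mindtorch.ones(4), persistent=False)   # kept out of state_dict

        def forward(self, x):
            return x * self.scale + self.running_mean

    m = Running()
    print(list(m.state_dict().keys()))           # ['scale', 'running_mean'] (no 'tmp')
    print(m.get_parameter("scale") is m.scale)   # True
    print(m.get_buffer("running_mean").shape)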
++ ++ Returns: ++ object: Any extra state to store in the module's state_dict ++ """ ++ raise RuntimeError( ++ "Reached a code path in Module.get_extra_state() that should never be called. " ++ "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " ++ "to report this bug.") ++ ++ ++ def set_extra_state(self, state: Any) -> None: ++ """Set extra state contained in the loaded `state_dict`. ++ ++ This function is called from :func:`load_state_dict` to handle any extra state ++ found within the `state_dict`. Implement this function and a corresponding ++ :func:`get_extra_state` for your module if you need to store extra state within its ++ `state_dict`. ++ ++ Args: ++ state (dict): Extra state from the `state_dict` ++ """ ++ raise RuntimeError( ++ "Reached a code path in Module.set_extra_state() that should never be called. " ++ "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " ++ "to report this bug.") ++ ++ def _apply(self, fn, recurse=True): ++ if recurse: ++ for module in self.children(): ++ module._apply(fn) ++ ++ def compute_should_use_set_data(tensor, tensor_applied): ++ if mindtorch._has_compatible_shallow_copy_type(tensor, tensor_applied): ++ # If the new tensor has compatible tensor type as the existing tensor, ++ # the current behavior is to change the tensor in-place using `.data =`, ++ # and the future behavior is to overwrite the existing tensor. However, ++ # changing the current behavior is a BC-breaking change, and we want it ++ # to happen in future releases. So for now we introduce the ++ # `mindtorch.__future__.get_overwrite_module_params_on_conversion()` ++ # global flag to let the user control whether they want the future ++ # behavior of overwriting the existing tensor or not. ++ return not mindtorch.__future__.get_overwrite_module_params_on_conversion() ++ else: ++ return False ++ ++ should_use_swap_tensors = ( ++ mindtorch.__future__.get_swap_module_params_on_conversion() ++ ) ++ ++ for key, param in self._parameters.items(): ++ if param is None: ++ continue ++ # Tensors stored in modules are graph leaves, and we don't want to ++ # track autograd history of `param_applied`, so we have to use ++ # `with mindtorch.no_grad():` ++ with mindtorch.no_grad(): ++ param_applied = fn(param) ++ p_should_use_set_data = compute_should_use_set_data(param, param_applied) ++ ++ # subclasses may have multiple child tensors so we need to use swap_tensors ++ p_should_use_swap_tensors = should_use_swap_tensors ++ ++ param_grad = param.grad ++ if p_should_use_swap_tensors: ++ try: ++ if param_grad is not None: ++ # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. 
++ # Decrement use count of the gradient by setting to None ++ param.grad = None ++ param_applied = Parameter( ++ param_applied, requires_grad=param.requires_grad ++ ) ++ mindtorch.utils.swap_tensors(param, param_applied) ++ except Exception as e: ++ if param_grad is not None: ++ param.grad = param_grad ++ raise RuntimeError( ++ f"_apply(): Couldn't swap {self._get_name()}.{key}" ++ ) from e ++ out_param = param ++ elif p_should_use_set_data: ++ param.data = param_applied ++ out_param = param ++ else: ++ assert isinstance(param, Parameter) ++ assert param.is_leaf ++ out_param = Parameter(param_applied, param.requires_grad) ++ self._parameters[key] = out_param ++ ++ if param_grad is not None: ++ with mindtorch.no_grad(): ++ grad_applied = fn(param_grad) ++ g_should_use_set_data = compute_should_use_set_data( ++ param_grad, grad_applied ++ ) ++ if p_should_use_swap_tensors: ++ grad_applied.requires_grad_(param_grad.requires_grad) ++ try: ++ mindtorch.utils.swap_tensors(param_grad, grad_applied) ++ except Exception as e: ++ raise RuntimeError( ++ f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" ++ ) from e ++ out_param.grad = param_grad ++ elif g_should_use_set_data: ++ assert out_param.grad is not None ++ out_param.grad.data = grad_applied ++ else: ++ assert param_grad.is_leaf ++ out_param.grad = grad_applied.requires_grad_( ++ param_grad.requires_grad ++ ) ++ ++ for key, buf in self._buffers.items(): ++ if buf is not None: ++ self._buffers[key] = fn(buf) ++ ++ return self ++ ++ def apply(self, fn): ++ """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) ++ as well as self. Typical use includes initializing the parameters of a model ++ (see also :ref:`torch-nn-init`). ++ ++ Args: ++ fn (:class:`Module` -> None): function to be applied to each submodule ++ ++ Returns: ++ Module: self ++ ++ Example: ++ >>> def init_weights(m): ++ >>> print(m) ++ >>> if type(m) == nn.Linear: ++ >>> m.weight.data.fill_(1.0) ++ >>> print(m.weight) ++ >>> ++ >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) ++ >>> net.apply(init_weights) ++ Linear (2 -> 2) ++ Parameter containing: ++ 1 1 ++ 1 1 ++ [mindtorch.Tensor of size 2x2] ++ Linear (2 -> 2) ++ Parameter containing: ++ 1 1 ++ 1 1 ++ [mindtorch.Tensor of size 2x2] ++ Sequential ( ++ (0): Linear (2 -> 2) ++ (1): Linear (2 -> 2) ++ ) ++ """ ++ for module in self.children(): ++ module.apply(fn) ++ fn(self) ++ return self ++ ++ def _wrapped_call_impl(self, *args, **kwargs): ++ if self._compiled_call_impl is not None: ++ return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] ++ return self._call_impl(*args, **kwargs) ++ ++ # torchrec tests the code consistency with the following code ++ # fmt: off ++ def _call_impl(self, *args, **kwargs): ++ forward_call = self.forward ++ # If we don't have any hooks, we want to skip the rest of the logic in ++ # this function, and just call forward. 
++ if self.__ms_class__: ++ return forward_call(*args, **kwargs) ++ ++ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks ++ or _global_backward_pre_hooks or _global_backward_hooks ++ or _global_forward_hooks or _global_forward_pre_hooks): ++ return forward_call(*args, **kwargs) ++ ++ try: ++ result = None ++ called_always_called_hooks = set() ++ ++ full_backward_hooks, non_full_backward_hooks = [], [] ++ backward_pre_hooks = [] ++ if self._backward_pre_hooks or _global_backward_pre_hooks: ++ backward_pre_hooks = self._get_backward_pre_hooks() ++ ++ if self._backward_hooks or _global_backward_hooks: ++ full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() ++ ++ if _global_forward_pre_hooks or self._forward_pre_hooks: ++ for hook_id, hook in ( ++ *_global_forward_pre_hooks.items(), ++ *self._forward_pre_hooks.items(), ++ ): ++ if hook_id in self._forward_pre_hooks_with_kwargs: ++ args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] ++ if args_kwargs_result is not None: ++ if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: ++ args, kwargs = args_kwargs_result ++ else: ++ raise RuntimeError( ++ "forward pre-hook must return None or a tuple " ++ f"of (new_args, new_kwargs), but got {args_kwargs_result}." ++ ) ++ else: ++ args_result = hook(self, args) ++ if args_result is not None: ++ if not isinstance(args_result, tuple): ++ args_result = (args_result,) ++ args = args_result ++ ++ bw_hook = None ++ # if full_backward_hooks or backward_pre_hooks: ++ # bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) ++ # args = bw_hook.setup_input_hook(args) ++ ++ result = forward_call(*args, **kwargs) ++ if _global_forward_hooks or self._forward_hooks: ++ for hook_id, hook in ( ++ *_global_forward_hooks.items(), ++ *self._forward_hooks.items(), ++ ): ++ # mark that always called hook is run ++ if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: ++ called_always_called_hooks.add(hook_id) ++ ++ if hook_id in self._forward_hooks_with_kwargs: ++ hook_result = hook(self, args, kwargs, result) ++ else: ++ hook_result = hook(self, args, result) ++ ++ if hook_result is not None: ++ result = hook_result ++ ++ if bw_hook: ++ if not isinstance(result, (mindtorch.Tensor, tuple)): ++ warnings.warn("For backward hooks to be called," ++ " module output should be a Tensor or a tuple of Tensors" ++ f" but received {type(result)}") ++ result = bw_hook.setup_output_hook(result) ++ ++ # Handle the non-full backward hooks ++ if non_full_backward_hooks: ++ var = result ++ while not isinstance(var, mindtorch.Tensor): ++ if isinstance(var, dict): ++ var = next(v for v in var.values() if isinstance(v, mindtorch.Tensor)) ++ else: ++ var = var[0] ++ # grad_fn = var.grad_fn ++ # if grad_fn is not None: ++ # for hook in non_full_backward_hooks: ++ # grad_fn.register_hook(_WrappedHook(hook, self)) ++ # self._maybe_warn_non_full_backward_hook(args, result, grad_fn) ++ ++ return result ++ ++ except Exception: ++ # run always called hooks if they have not already been run ++ # For now only forward hooks have the always_call option but perhaps ++ # this functionality should be added to full backward hooks as well. 
++ for hook_id, hook in _global_forward_hooks.items(): ++ if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] ++ try: ++ hook_result = hook(self, args, result) # type: ignore[possibly-undefined] ++ if hook_result is not None: ++ result = hook_result ++ except Exception as e: ++ warnings.warn("global module forward hook with ``always_call=True`` raised an exception " ++ f"that was silenced as another error was raised in forward: {str(e)}") ++ continue ++ ++ for hook_id, hook in self._forward_hooks.items(): ++ if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] ++ try: ++ if hook_id in self._forward_hooks_with_kwargs: ++ hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] ++ else: ++ hook_result = hook(self, args, result) # type: ignore[possibly-undefined] ++ if hook_result is not None: ++ result = hook_result ++ except Exception as e: ++ warnings.warn("module forward hook with ``always_call=True`` raised an exception " ++ f"that was silenced as another error was raised in forward: {str(e)}") ++ continue ++ # raise exception raised in try block ++ raise ++ # fmt: on ++ ++ __call__: Callable[..., Any] = _wrapped_call_impl ++ ++ def __getstate__(self): ++ state = self.__dict__.copy() ++ state.pop("_compiled_call_impl", None) ++ return state ++ ++ def __setstate__(self, state): ++ self.__dict__.update(state) ++ ++ # Support loading old checkpoints that don't have the following attrs: ++ if "_forward_pre_hooks" not in self.__dict__: ++ self._forward_pre_hooks = OrderedDict() ++ if "_forward_pre_hooks_with_kwargs" not in self.__dict__: ++ self._forward_pre_hooks_with_kwargs = OrderedDict() ++ if "_forward_hooks_with_kwargs" not in self.__dict__: ++ self._forward_hooks_with_kwargs = OrderedDict() ++ if "_forward_hooks_always_called" not in self.__dict__: ++ self._forward_hooks_always_called = OrderedDict() ++ if "_state_dict_hooks" not in self.__dict__: ++ self._state_dict_hooks = OrderedDict() ++ if "_state_dict_pre_hooks" not in self.__dict__: ++ self._state_dict_pre_hooks = OrderedDict() ++ if "_load_state_dict_pre_hooks" not in self.__dict__: ++ self._load_state_dict_pre_hooks = OrderedDict() ++ if "_load_state_dict_post_hooks" not in self.__dict__: ++ self._load_state_dict_post_hooks = OrderedDict() ++ if "_non_persistent_buffers_set" not in self.__dict__: ++ self._non_persistent_buffers_set = set() ++ if "_is_full_backward_hook" not in self.__dict__: ++ self._is_full_backward_hook = None ++ if "_backward_pre_hooks" not in self.__dict__: ++ self._backward_pre_hooks = OrderedDict() ++ ++ def __getattr__(self, name): ++ if '_parameters' in self.__dict__: ++ _parameters = self.__dict__['_parameters'] ++ if name in _parameters: ++ return _parameters[name] ++ if '_buffers' in self.__dict__: ++ _buffers = self.__dict__['_buffers'] ++ if name in _buffers: ++ return _buffers[name] ++ if '_modules' in self.__dict__: ++ modules = self.__dict__['_modules'] ++ if name in modules: ++ return modules[name] ++ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") ++ ++ def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: ++ def remove_from(*dicts_or_sets): ++ for d in dicts_or_sets: ++ if name in d: ++ if isinstance(d, dict): ++ del d[name] ++ else: ++ d.discard(name) ++ ++ params = self.__dict__.get("_parameters") ++ if isinstance(value, Parameter): ++ if params is None: ++ 
raise AttributeError( ++ "cannot assign parameters before Module.__init__() call" ++ ) ++ remove_from( ++ self.__dict__, ++ self._buffers, ++ self._modules, ++ self._non_persistent_buffers_set, ++ ) ++ self.register_parameter(name, value) ++ elif params is not None and name in params: ++ if value is not None: ++ raise TypeError( ++ f"cannot assign '{mindtorch.typename(value)}' as parameter '{name}' " ++ "(mindtorch.nn.Parameter or None expected)" ++ ) ++ self.register_parameter(name, value) ++ else: ++ modules = self.__dict__.get("_modules") ++ if isinstance(value, Module): ++ if modules is None: ++ raise AttributeError( ++ "cannot assign module before Module.__init__() call" ++ ) ++ remove_from( ++ self.__dict__, ++ self._parameters, ++ self._buffers, ++ self._non_persistent_buffers_set, ++ ) ++ for hook in _global_module_registration_hooks.values(): ++ output = hook(self, name, value) ++ if output is not None: ++ value = output ++ modules[name] = value ++ ++ elif modules is not None and name in modules: ++ if value is not None: ++ raise TypeError( ++ f"cannot assign '{mindtorch.typename(value)}' as child module '{name}' " ++ "(mindtorch.nn.Module or None expected)" ++ ) ++ for hook in _global_module_registration_hooks.values(): ++ output = hook(self, name, value) ++ if output is not None: ++ value = output ++ modules[name] = value ++ else: ++ buffers = self.__dict__.get("_buffers") ++ if isinstance(value, Buffer) or buffers is not None and name in buffers: ++ if value is not None and not isinstance(value, mindtorch.Tensor): ++ raise TypeError( ++ f"cannot assign '{mindtorch.typename(value)}' as buffer '{name}' " ++ "(mindtorch.nn.Buffer, mindtorch.Tensor or None expected)" ++ ) ++ if isinstance(value, Buffer): ++ persistent = value.persistent ++ else: ++ persistent = name not in self._non_persistent_buffers_set ++ # === HACK === ++ # This whole block below should just be: ++ # self.register_buffer(name, value, persistent) ++ ++ # But to support subclasses of nn.Module that (wrongfully) implement a ++ # register_buffer() method that doesn't have the "persistent" ++ # argument. Only pass it in if it is accepted otherwise assume ++ # it is always true ++ if ( ++ getattr(self.register_buffer, "__func__", None) ++ is Module.register_buffer ++ ): ++ self.register_buffer(name, value, persistent) ++ else: ++ sign = inspect.signature(self.register_buffer) ++ if "persistent" in sign.parameters: ++ self.register_buffer(name, value, persistent) ++ else: ++ if not persistent: ++ raise RuntimeError( ++ "Registering a non-persistent buffer " ++ "on a Module subclass that implements " ++ "register_buffer() without the persistent " ++ "argument is not allowed." ++ ) ++ # Assume that the implementation without the argument has the ++ # behavior from before the argument was added: persistent=True ++ self.register_buffer(name, value) ++ # === HACK END === ++ else: ++ super().__setattr__(name, value) ++ ++ def __delattr__(self, name): ++ if name in self._parameters: ++ del self._parameters[name] ++ elif name in self._buffers: ++ del self._buffers[name] ++ self._non_persistent_buffers_set.discard(name) ++ elif name in self._modules: ++ del self._modules[name] ++ else: ++ super().__delattr__(name) ++ ++ def _register_state_dict_hook(self, hook): ++ r"""Register a post-hook for the :meth:`~mindtorch.nn.Module.state_dict` method. 
++ ++ It should have the following signature:: ++ hook(module, state_dict, prefix, local_metadata) -> None or state_dict ++ ++ The registered hooks can modify the ``state_dict`` inplace or return a new one. ++ If a new ``state_dict`` is returned, it will only be respected if it is the root ++ module that :meth:`~nn.Module.state_dict` is called from. ++ """ ++ if getattr(hook, "_from_public_api", False): ++ raise RuntimeError( ++ "Cannot register the same function as the state dict post hook that was " ++ "previously registered via register_state_dict_post_hook" ++ ) ++ handle = RemovableHandle(self._state_dict_hooks) ++ self._state_dict_hooks[handle.id] = hook ++ return handle ++ ++ def extra_repr(self) -> str: ++ r"""Set the extra representation of the module. ++ ++ To print customized extra information, you should re-implement ++ this method in your own modules. Both single-line and multi-line ++ strings are acceptable. ++ """ ++ return '' ++ ++ ++ def __repr__(self): ++ # We treat the extra repr like the sub-module, one item per line ++ extra_lines = [] ++ extra_repr = self.extra_repr() ++ # empty string will be split into list [''] ++ if extra_repr: ++ extra_lines = extra_repr.split('\n') ++ child_lines = [] ++ for key, module in self._modules.items(): ++ mod_str = repr(module) ++ mod_str = _addindent(mod_str, 2) ++ child_lines.append('(' + key + '): ' + mod_str) ++ lines = extra_lines + child_lines ++ ++ main_str = self._get_name() + '(' ++ if lines: ++ # simple one-liner info, which most builtin Modules will use ++ if len(extra_lines) == 1 and not child_lines: ++ main_str += extra_lines[0] ++ else: ++ main_str += '\n ' + '\n '.join(lines) + '\n' ++ ++ main_str += ')' ++ return main_str ++ ++ def __dir__(self): ++ module_attrs = dir(self.__class__) ++ attrs = list(self.__dict__.keys()) ++ parameters = list(self._parameters.keys()) ++ modules = list(self._modules.keys()) ++ buffers = list(self._buffers.keys()) ++ keys = module_attrs + attrs + parameters + modules + buffers ++ ++ # Eliminate attrs that are not legal Python variable names ++ keys = [key for key in keys if not key[0].isdigit()] ++ ++ return sorted(keys) ++ ++ def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: ++ r"""Move all model parameters and buffers to the GPU. ++ ++ This also makes associated parameters and buffers different objects. So ++ it should be called before constructing optimizer if the module will ++ live on GPU while being optimized. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Args: ++ device (int, optional): if specified, all parameters will be ++ copied to that device ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.cuda(device)) ++ ++ def npu(self: T, device: Optional[Union[int, device]] = None) -> T: ++ return self._apply(lambda t: t.npu(device)) ++ ++ def cpu(self: T, device: Optional[Union[int, device]] = None) -> T: ++ return self._apply(lambda t: t.cpu()) ++ ++ ++ def _load_from_state_dict( ++ self, ++ state_dict, ++ prefix, ++ local_metadata, ++ strict, ++ missing_keys, ++ unexpected_keys, ++ error_msgs, ++ ) -> None: ++ r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. ++ ++ This is called on every submodule ++ in :meth:`~mindtorch.nn.Module.load_state_dict`. Metadata saved for this ++ module in input :attr:`state_dict` is provided as :attr:`local_metadata`. ++ For state dicts without metadata, :attr:`local_metadata` is empty. 
++ Subclasses can achieve class-specific backward compatible loading using ++ the version number at `local_metadata.get("version", None)`. ++ Additionally, :attr:`local_metadata` can also contain the key ++ `assign_to_params_buffers` that indicates whether keys should be ++ assigned their corresponding tensor in the state_dict. ++ ++ .. note:: ++ :attr:`state_dict` is not the same object as the input ++ :attr:`state_dict` to :meth:`~mindtorch.nn.Module.load_state_dict`. So ++ it can be modified. ++ ++ Args: ++ state_dict (dict): a dict containing parameters and ++ persistent buffers. ++ prefix (str): the prefix for parameters and buffers used in this ++ module ++ local_metadata (dict): a dict containing the metadata for this module. ++ See ++ strict (bool): whether to strictly enforce that the keys in ++ :attr:`state_dict` with :attr:`prefix` match the names of ++ parameters and buffers in this module ++ missing_keys (list of str): if ``strict=True``, add missing keys to ++ this list ++ unexpected_keys (list of str): if ``strict=True``, add unexpected ++ keys to this list ++ error_msgs (list of str): error messages should be added to this ++ list, and will be reported together in ++ :meth:`~mindtorch.nn.Module.load_state_dict` ++ """ ++ for hook in self._load_state_dict_pre_hooks.values(): ++ hook( ++ state_dict, ++ prefix, ++ local_metadata, ++ strict, ++ missing_keys, ++ unexpected_keys, ++ error_msgs, ++ ) ++ ++ persistent_buffers = { ++ k: v ++ for k, v in self._buffers.items() ++ if k not in self._non_persistent_buffers_set ++ } ++ local_name_params = itertools.chain( ++ self._parameters.items(), persistent_buffers.items() ++ ) ++ local_state = {k: v for k, v in local_name_params if v is not None} ++ assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) ++ use_swap_tensors = mindtorch.__future__.get_swap_module_params_on_conversion() ++ ++ for name, param in local_state.items(): ++ key = prefix + name ++ if key in state_dict: ++ input_param = state_dict[key] ++ if not mindtorch.overrides.is_tensor_like(input_param): ++ error_msgs.append( ++ f'While copying the parameter named "{key}", ' ++ "expected mindtorch.Tensor or Tensor-like object from checkpoint but " ++ f"received {type(input_param)}" ++ ) ++ continue ++ ++ # This is used to avoid copying uninitialized parameters into ++ # non-lazy modules, since they dont have the hook to do the checks ++ # in such case, it will error when accessing the .shape attribute. ++ is_param_lazy = mindtorch.nn.parameter.is_lazy(param) ++ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ ++ if ( ++ not is_param_lazy ++ and len(param.shape) == 0 ++ and len(input_param.shape) == 1 ++ ): ++ input_param = input_param[0] ++ ++ if not is_param_lazy and input_param.shape != param.shape: ++ # local shape should match the one in checkpoint ++ error_msgs.append( ++ f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " ++ f"the shape in current model is {param.shape}." ++ ) ++ continue ++ ++ if ( ++ param.is_meta ++ and not input_param.is_meta ++ and not assign_to_params_buffers ++ ): ++ warnings.warn( ++ f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " ++ "parameter in the current model, which is a no-op. 
(Did you mean to " ++ "pass `assign=True` to assign items in the state dictionary to their " ++ "corresponding key in the module instead of copying them in place?)" ++ ) ++ ++ try: ++ with mindtorch.no_grad(): ++ if use_swap_tensors: ++ new_input_param = param.module_load( ++ input_param, assign=assign_to_params_buffers ++ ) ++ if id(new_input_param) == id(input_param) or id( ++ new_input_param ++ ) == id(param): ++ raise RuntimeError( ++ "module_load returned one of self or other, please .detach() " ++ "the result if returning one of the inputs in module_load" ++ ) ++ if isinstance(param, mindtorch.nn.Parameter): ++ if not isinstance(new_input_param, mindtorch.nn.Parameter): ++ new_input_param = mindtorch.nn.Parameter( ++ new_input_param, ++ requires_grad=param.requires_grad, ++ ) ++ else: ++ new_input_param.requires_grad_(param.requires_grad) ++ mindtorch.utils.swap_tensors(param, new_input_param) ++ del new_input_param ++ elif assign_to_params_buffers: ++ # Shape checks are already done above ++ if isinstance(param, mindtorch.nn.Parameter): ++ if not isinstance(input_param, mindtorch.nn.Parameter): ++ input_param = mindtorch.nn.Parameter( ++ input_param, requires_grad=param.requires_grad ++ ) ++ else: ++ input_param.requires_grad_(param.requires_grad) ++ setattr(self, name, input_param) ++ else: ++ param.copy_(input_param) ++ except Exception as ex: ++ action = "swapping" if use_swap_tensors else "copying" ++ error_msgs.append( ++ f'While {action} the parameter named "{key}", ' ++ f"whose dimensions in the model are {param.size()} and " ++ f"whose dimensions in the checkpoint are {input_param.size()}, " ++ f"an exception occurred : {ex.args}." ++ ) ++ elif strict: ++ missing_keys.append(key) ++ ++ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX ++ if ( ++ getattr(self.__class__, "set_extra_state", Module.set_extra_state) ++ is not Module.set_extra_state ++ ): ++ if extra_state_key in state_dict: ++ self.set_extra_state(state_dict[extra_state_key]) ++ elif strict: ++ missing_keys.append(extra_state_key) ++ elif strict and (extra_state_key in state_dict): ++ unexpected_keys.append(extra_state_key) ++ ++ if strict: ++ for key in state_dict.keys(): ++ if key.startswith(prefix) and key != extra_state_key: ++ input_name = key[len(prefix) :].split(".", 1) ++ # Must be Module if it have attributes ++ if len(input_name) > 1: ++ if input_name[0] not in self._modules: ++ unexpected_keys.append(key) ++ elif input_name[0] not in local_state: ++ unexpected_keys.append(key) ++ ++ def load_state_dict(self, state_dict: Mapping[str, Any], ++ strict: bool = True, assign: bool = False): ++ r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. ++ ++ If :attr:`strict` is ``True``, then ++ the keys of :attr:`state_dict` must exactly match the keys returned ++ by this module's :meth:`~nn.Module.state_dict` function. ++ ++ Args: ++ state_dict (dict): a dict containing parameters and ++ persistent buffers. ++ strict (bool, optional): whether to strictly enforce that the keys ++ in :attr:`state_dict` match the keys returned by this module's ++ :meth:`~nn.Module.state_dict` function. Default: ``True`` ++ assign (bool, optional): When ``False``, the properties of the tensors ++ in the current module are preserved while when ``True``, the ++ properties of the Tensors in the state dict are preserved. The only ++ exception is the ``requires_grad`` field of :class:`~nn.Parameter`s ++ for which the value from the module is preserved. 
++ Default: ``False`` ++ ++ Returns: ++ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: ++ * **missing_keys** is a list of str containing the missing keys ++ * **unexpected_keys** is a list of str containing the unexpected keys ++ ++ Note: ++ If a parameter or buffer is registered as ``None`` and its corresponding key ++ exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a ++ ``RuntimeError``. ++ """ ++ if not isinstance(state_dict, Mapping): ++ raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.") ++ ++ missing_keys: List[str] = [] ++ unexpected_keys: List[str] = [] ++ error_msgs: List[str] = [] ++ ++ # copy state_dict so _load_from_state_dict can modify it ++ metadata = getattr(state_dict, '_metadata', None) ++ state_dict = OrderedDict(state_dict) ++ ++ if metadata is not None: ++ # mypy isn't aware that "_metadata" exists in state_dict ++ state_dict._metadata = metadata # type: ignore[attr-defined] ++ ++ def load(module, local_state_dict, prefix=''): ++ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) ++ if assign: ++ local_metadata['assign_to_params_buffers'] = assign ++ module._load_from_state_dict( ++ local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) ++ for name, child in module._modules.items(): ++ if child is not None: ++ child_prefix = prefix + name + '.' ++ child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)} ++ load(child, child_state_dict, child_prefix) # noqa: F821 ++ ++ # Note that the hook can modify missing_keys and unexpected_keys. ++ incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) ++ for hook in module._load_state_dict_post_hooks.values(): ++ out = hook(module, incompatible_keys) ++ assert out is None, ( ++ "Hooks registered with ``register_load_state_dict_post_hook`` are not" ++ "expected to return new values, if incompatible_keys need to be modified," ++ "it should be done inplace." ++ ) ++ ++ load(self, state_dict) ++ del load ++ ++ if strict: ++ if len(unexpected_keys) > 0: ++ error_msgs.insert( ++ 0, 'Unexpected key(s) in state_dict: {}. '.format( ++ ', '.join(f'"{k}"' for k in unexpected_keys))) ++ if len(missing_keys) > 0: ++ error_msgs.insert( ++ 0, 'Missing key(s) in state_dict: {}. '.format( ++ ', '.join(f'"{k}"' for k in missing_keys))) ++ ++ if len(error_msgs) > 0: ++ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( ++ self.__class__.__name__, "\n\t".join(error_msgs))) ++ return _IncompatibleKeys(missing_keys, unexpected_keys) ++ ++ ++ def _named_members( ++ self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True ++ ): ++ r"""Help yield various names + members of modules.""" ++ memo = set() ++ modules = ( ++ self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) ++ if recurse ++ else [(prefix, self)] ++ ) ++ for module_prefix, module in modules: ++ members = get_members_fn(module) ++ for k, v in members: ++ if v is None or v in memo: ++ continue ++ if remove_duplicate: ++ memo.add(v) ++ name = module_prefix + ("." if module_prefix else "") + k ++ yield name, v ++ ++ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: ++ r"""Return an iterator over module parameters. ++ ++ This is typically passed to an optimizer. ++ ++ Args: ++ recurse (bool): if True, then yields parameters of this module ++ and all submodules. Otherwise, yields only parameters that ++ are direct members of this module. 
++ ++ Yields: ++ Parameter: module parameter ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> for param in model.parameters(): ++ >>> print(type(param), param.shape) ++ (20L,) ++ (20L, 1L, 5L, 5L) ++ ++ """ ++ for name, param in self.named_parameters(recurse=recurse): ++ yield param ++ ++ def trainable_params(self, recurse: bool = True): ++ def _ensure_ms_parameter(param_obj, base_name, index=None): ++ if isinstance(param_obj, MsParameter): ++ return param_obj ++ if isinstance(param_obj, Parameter): ++ tensor = param_obj ++ else: ++ tensor = param_obj ++ suffix = f"_{index}" if index is not None else "" ++ param_name = getattr(param_obj, "name", None) ++ if not param_name: ++ param_name = f"{base_name}{suffix}" ++ return MsParameter(tensor, name=param_name) ++ ++ params = [] ++ for name, param in self.named_parameters(recurse=recurse): ++ if not param.requires_grad: ++ continue ++ if isinstance(param, ParameterTuple): ++ for idx, inner_param in enumerate(param): ++ params.append(_ensure_ms_parameter(inner_param, name, idx)) ++ else: ++ params.append(_ensure_ms_parameter(param, name)) ++ ++ return ParameterTuple(tuple(params)) ++ ++ def get_submodule(self, target: str) -> "Module": ++ """Return the submodule given by ``target`` if it exists, otherwise throw an error. ++ ++ For example, let's say you have an ``nn.Module`` ``A`` that ++ looks like this: ++ ++ .. code-block:: text ++ ++ A( ++ (net_b): Module( ++ (net_c): Module( ++ (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ++ ) ++ (linear): Linear(in_features=100, out_features=200, bias=True) ++ ) ++ ) ++ ++ (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested ++ submodule ``net_b``, which itself has two submodules ``net_c`` ++ and ``linear``. ``net_c`` then has a submodule ``conv``.) ++ ++ To check whether or not we have the ``linear`` submodule, we ++ would call ``get_submodule("net_b.linear")``. To check whether ++ we have the ``conv`` submodule, we would call ++ ``get_submodule("net_b.net_c.conv")``. ++ ++ The runtime of ``get_submodule`` is bounded by the degree ++ of module nesting in ``target``. A query against ++ ``named_modules`` achieves the same result, but it is O(N) in ++ the number of transitive modules. So, for a simple check to see ++ if some submodule exists, ``get_submodule`` should always be ++ used. ++ ++ Args: ++ target: The fully-qualified string name of the submodule ++ to look for. (See above example for how to specify a ++ fully-qualified string.) ++ ++ Returns: ++ nn.Module: The submodule referenced by ``target`` ++ ++ Raises: ++ AttributeError: If the target string references an invalid ++ path or resolves to something that is not an ++ ``nn.Module`` ++ """ ++ if target == "": ++ return self ++ ++ atoms: List[str] = target.split(".") ++ mod: Module = self ++ ++ for item in atoms: ++ ++ if not hasattr(mod, item): ++ raise AttributeError(mod._get_name() + " has no " ++ "attribute `" + item + "`") ++ ++ mod = getattr(mod, item) ++ ++ if not isinstance(mod, Module): ++ raise AttributeError("`" + item + "` is not " ++ "an nn.Module") ++ ++ return mod ++ ++ def get_parameters(self, expand=True): ++ return self.parameters(expand) ++ ++ def named_parameters( ++ self, ++ prefix: str = '', ++ recurse: bool = True, ++ remove_duplicate: bool = True ++ ) -> Iterator[Tuple[str, Parameter]]: ++ r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. ++ ++ Args: ++ prefix (str): prefix to prepend to all parameter names. 
++ recurse (bool): if True, then yields parameters of this module ++ and all submodules. Otherwise, yields only parameters that ++ are direct members of this module. ++ remove_duplicate (bool, optional): whether to remove the duplicated ++ parameters in the result. Defaults to True. ++ ++ Yields: ++ (str, Parameter): Tuple containing the name and parameter ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> for name, param in self.named_parameters(): ++ >>> if name in ['bias']: ++ >>> print(param.shape) ++ ++ """ ++ gen = self._named_members( ++ lambda module: module._parameters.items(), ++ prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) ++ yield from gen ++ ++ def parameters_and_names(self, name_prefix='', expand=True): ++ return self.named_parameters(name_prefix, expand) ++ ++ def buffers(self, recurse: bool = True) -> Iterator[Tensor]: ++ r"""Return an iterator over module buffers. ++ ++ Args: ++ recurse (bool): if True, then yields buffers of this module ++ and all submodules. Otherwise, yields only buffers that ++ are direct members of this module. ++ ++ Yields: ++ mindtorch.Tensor: module buffer ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> for buf in model.buffers(): ++ >>> print(type(buf), buf.shape) ++ (20L,) ++ (20L, 1L, 5L, 5L) ++ ++ """ ++ for _, buf in self.named_buffers(recurse=recurse): ++ yield buf ++ ++ ++ def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]: ++ r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. ++ ++ Args: ++ prefix (str): prefix to prepend to all buffer names. ++ recurse (bool, optional): if True, then yields buffers of this module ++ and all submodules. Otherwise, yields only buffers that ++ are direct members of this module. Defaults to True. ++ remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. ++ ++ Yields: ++ (str, mindtorch.Tensor): Tuple containing the name and buffer ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> for name, buf in self.named_buffers(): ++ >>> if name in ['running_var']: ++ >>> print(buf.shape) ++ ++ """ ++ gen = self._named_members( ++ lambda module: module._buffers.items(), ++ prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) ++ yield from gen ++ ++ def _all_buffers(self, memo=None): ++ if memo is None: ++ memo = set() ++ for name, b in self._buffers.items(): ++ if b is not None and b not in memo: ++ memo.add(b) ++ yield b ++ for module in self.children(): ++ for b in module._all_buffers(memo): ++ yield b ++ ++ def children(self): ++ """Returns an iterator over immediate children modules. ++ ++ Yields: ++ Module: a child module ++ """ ++ for name, module in self.named_children(): ++ yield module ++ ++ def named_children(self): ++ """Returns an iterator over immediate children modules, yielding both ++ the name of the module as well as the module itself. ++ ++ Yields: ++ (string, Module): Tuple containing a name and child module ++ ++ Example: ++ >>> for name, module in model.named_children(): ++ >>> if name in ['conv4', 'conv5']: ++ >>> print(module) ++ """ ++ memo = set() ++ for name, module in self._modules.items(): ++ if module is not None and module not in memo: ++ memo.add(module) ++ yield name, module ++ ++ def modules(self): ++ """Returns an iterator over all modules in the network. 
++ ++ Yields: ++ Module: a module in the network ++ ++ Note: ++ Duplicate modules are returned only once. In the following ++ example, ``l`` will be returned only once. ++ ++ >>> l = nn.Linear(2, 2) ++ >>> net = nn.Sequential(l, l) ++ >>> for idx, m in enumerate(net.modules()): ++ >>> print(idx, '->', m) ++ 0 -> Sequential ( ++ (0): Linear (2 -> 2) ++ (1): Linear (2 -> 2) ++ ) ++ 1 -> Linear (2 -> 2) ++ """ ++ for name, module in self.named_modules(): ++ yield module ++ ++ def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): ++ r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. ++ ++ Args: ++ memo: a memo to store the set of modules already added to the result ++ prefix: a prefix that will be added to the name of the module ++ remove_duplicate: whether to remove the duplicated module instances in the result ++ or not ++ ++ Yields: ++ (str, Module): Tuple of name and module ++ ++ Note: ++ Duplicate modules are returned only once. In the following ++ example, ``l`` will be returned only once. ++ ++ Example:: ++ ++ >>> l = nn.Linear(2, 2) ++ >>> net = nn.Sequential(l, l) ++ >>> for idx, m in enumerate(net.named_modules()): ++ ... print(idx, '->', m) ++ ++ 0 -> ('', Sequential( ++ (0): Linear(in_features=2, out_features=2, bias=True) ++ (1): Linear(in_features=2, out_features=2, bias=True) ++ )) ++ 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) ++ ++ """ ++ if memo is None: ++ memo = set() ++ if self not in memo: ++ if remove_duplicate: ++ memo.add(self) ++ yield prefix, self ++ for name, module in self._modules.items(): ++ if module is None: ++ continue ++ submodule_prefix = prefix + ('.' if prefix else '') + name ++ yield from module.named_modules(memo, submodule_prefix, remove_duplicate) ++ ++ def jit(self, mode=True): ++ self.__ms_class__ = mode ++ for module in self.children(): ++ module.jit(mode) ++ return self ++ ++ def compile(self, *args, **kwargs): ++ self.jit() ++ def forward_fn(*args, **kwargs): ++ return self.forward(*args, **kwargs) ++ ++ # forward_fn = mindspore.jit(forward_fn, *args, **kwargs) ++ self._compiled_call_impl = forward_fn ++ ++ @property ++ def skip_syntax(self): ++ return self.__ms_class__ ++ ++ def train(self, mode=True): ++ """Sets the module in training mode. ++ ++ This has any effect only on modules such as Dropout or BatchNorm. ++ ++ Returns: ++ Module: self ++ """ ++ self.training = mode ++ for module in self.children(): ++ module.train(mode) ++ return self ++ ++ set_train = train ++ ++ def eval(self): ++ """Sets the module in evaluation mode. ++ ++ This has any effect only on modules such as Dropout or BatchNorm. ++ """ ++ return self.train(False) ++ ++ def requires_grad_(self: T, requires_grad: bool = True) -> T: ++ r"""Change if autograd should record operations on parameters in this module. ++ ++ This method sets the parameters' :attr:`requires_grad` attributes ++ in-place. ++ ++ This method is helpful for freezing part of the module for finetuning ++ or training parts of a model individually (e.g., GAN training). ++ ++ See :ref:`locally-disable-grad-doc` for a comparison between ++ `.requires_grad_()` and several similar mechanisms that may be confused with it. ++ ++ Args: ++ requires_grad (bool): whether autograd should record operations on ++ parameters in this module. Default: ``True``. 
++ ++ Returns: ++ Module: self ++ """ ++ for p in self.parameters(): ++ p.requires_grad = requires_grad ++ return self ++ ++ ++ def _get_name(self): ++ return self.__class__.__name__ ++ ++ def to(self, *args, **kwargs): ++ r"""Move and/or cast the parameters and buffers. ++ ++ This can be called as ++ ++ .. function:: to(device=None, dtype=None, non_blocking=False) ++ :noindex: ++ ++ .. function:: to(dtype, non_blocking=False) ++ :noindex: ++ ++ .. function:: to(tensor, non_blocking=False) ++ :noindex: ++ ++ .. function:: to(memory_format=mindtorch.channels_last) ++ :noindex: ++ ++ Its signature is similar to :meth:`mindtorch.Tensor.to`, but only accepts ++ floating point or complex :attr:`dtype`\ s. In addition, this method will ++ only cast the floating point or complex parameters and buffers to :attr:`dtype` ++ (if given). The integral parameters and buffers will be moved ++ :attr:`device`, if that is given, but with dtypes unchanged. When ++ :attr:`non_blocking` is set, it tries to convert/move asynchronously ++ with respect to the host if possible, e.g., moving CPU Tensors with ++ pinned memory to CUDA devices. ++ ++ See below for examples. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Args: ++ device (:class:`mindtorch.device`): the desired device of the parameters ++ and buffers in this module ++ dtype (:class:`mindtorch.dtype`): the desired floating point or complex dtype of ++ the parameters and buffers in this module ++ tensor (mindtorch.Tensor): Tensor whose dtype and device are the desired ++ dtype and device for all parameters and buffers in this module ++ memory_format (:class:`mindtorch.memory_format`): the desired memory ++ format for 4D parameters and buffers in this module (keyword ++ only argument) ++ ++ Returns: ++ Module: self ++ ++ Examples:: ++ ++ >>> # xdoctest: +IGNORE_WANT("non-deterministic") ++ >>> linear = nn.Linear(2, 2) ++ >>> linear.weight ++ Parameter containing: ++ tensor([[ 0.1913, -0.3420], ++ [-0.5113, -0.2325]]) ++ >>> linear.to(mindtorch.double) ++ Linear(in_features=2, out_features=2, bias=True) ++ >>> linear.weight ++ Parameter containing: ++ tensor([[ 0.1913, -0.3420], ++ [-0.5113, -0.2325]], dtype=mindtorch.float64) ++ >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) ++ >>> gpu1 = mindtorch.device("cuda:1") ++ >>> linear.to(gpu1, dtype=mindtorch.half, non_blocking=True) ++ Linear(in_features=2, out_features=2, bias=True) ++ >>> linear.weight ++ Parameter containing: ++ tensor([[ 0.1914, -0.3420], ++ [-0.5112, -0.2324]], dtype=mindtorch.float16, device='cuda:1') ++ >>> cpu = mindtorch.device("cpu") ++ >>> linear.to(cpu) ++ Linear(in_features=2, out_features=2, bias=True) ++ >>> linear.weight ++ Parameter containing: ++ tensor([[ 0.1914, -0.3420], ++ [-0.5112, -0.2324]], dtype=mindtorch.float16) ++ ++ >>> linear = nn.Linear(2, 2, bias=None).to(mindtorch.cdouble) ++ >>> linear.weight ++ Parameter containing: ++ tensor([[ 0.3741+0.j, 0.2382+0.j], ++ [ 0.5593+0.j, -0.4443+0.j]], dtype=mindtorch.complex128) ++ >>> linear(mindtorch.ones(3, 2, dtype=mindtorch.cdouble)) ++ tensor([[0.6122+0.j, 0.1150+0.j], ++ [0.6122+0.j, 0.1150+0.j], ++ [0.6122+0.j, 0.1150+0.j]], dtype=mindtorch.complex128) ++ ++ """ ++ device, dtype, non_blocking, convert_to_format = mindtorch._C._nn._parse_to( ++ *args, **kwargs ++ ) ++ ++ if dtype is not None: ++ if not (dtype.is_floating_point or dtype.is_complex): ++ raise TypeError( ++ "nn.Module.to only accepts floating point or complex " ++ f"dtypes, but got desired dtype={dtype}" ++ ) ++ if 
dtype.is_complex: ++ warnings.warn( ++ "Complex modules are a new feature under active development whose design may change, " ++ "and some modules might not work as expected when using complex tensors as parameters or buffers. " ++ "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " ++ "if a complex module does not work as expected." ++ ) ++ ++ def convert(t): ++ try: ++ if convert_to_format is not None and t.dim() in (4, 5): ++ return t.to( ++ device, ++ dtype if t.is_floating_point() or t.is_complex() else None, ++ non_blocking, ++ memory_format=convert_to_format, ++ ) ++ return t.to( ++ device, ++ dtype if t.is_floating_point() or t.is_complex() else None, ++ non_blocking=non_blocking, ++ ) ++ except NotImplementedError as e: ++ if str(e) == "Cannot copy out of meta tensor; no data!": ++ raise NotImplementedError( ++ f"{e} Please use mindtorch.nn.Module.to_empty() instead of mindtorch.nn.Module.to() " ++ f"when moving module from meta to a different device." ++ ) from None ++ else: ++ raise ++ ++ return self._apply(convert) ++ ++ def half(self: T) -> T: ++ r"""Casts all floating point parameters and buffers to ``half`` datatype. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.half() if t.is_floating_point() else t) ++ ++ def bfloat16(self: T) -> T: ++ r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) ++ ++ def to_empty( ++ self, *, device, recurse: bool = True ++ ): ++ r"""Move the parameters and buffers to the specified device without copying storage. ++ ++ Args: ++ device (:class:`mindtorch.device`): The desired device of the parameters ++ and buffers in this module. ++ recurse (bool): Whether parameters and buffers of submodules should ++ be recursively moved to the specified device. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply( ++ lambda t: mindtorch.empty_like(t, device=device), recurse=recurse ++ ) ++ ++ def float(self: T) -> T: ++ r"""Casts all floating point parameters and buffers to ``float`` datatype. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.float() if t.is_floating_point() else t) ++ ++ ++ def double(self: T) -> T: ++ r"""Casts all floating point parameters and buffers to ``double`` datatype. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.double() if t.is_floating_point() else t) ++ ++ ++ def half(self: T) -> T: ++ r"""Casts all floating point parameters and buffers to ``half`` datatype. ++ ++ .. note:: ++ This method modifies the module in-place. ++ ++ Returns: ++ Module: self ++ """ ++ return self._apply(lambda t: t.half() if t.is_floating_point() else t) ++ ++ ++ def _save_to_state_dict(self, destination, prefix, keep_vars): ++ r"""Save module state to the `destination` dictionary. ++ ++ The `destination` dictionary will contain the state ++ of the module, but not its descendants. This is called on every ++ submodule in :meth:`~nn.Module.state_dict`. ++ ++ In rare cases, subclasses can achieve class-specific behavior by ++ overriding this method with custom logic. 
++ ++ Args: ++ destination (dict): a dict where state will be stored ++ prefix (str): the prefix for parameters and buffers used in this ++ module ++ """ ++ for name, param in self._parameters.items(): ++ if param is not None: ++ destination[prefix + name] = param ++ for name, buf in self._buffers.items(): ++ if buf is not None and name not in self._non_persistent_buffers_set: ++ destination[prefix + name] = buf ++ extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX ++ if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: ++ destination[extra_state_key] = self.get_extra_state() ++ ++ # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns ++ # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. ++ T_destination = TypeVar('T_destination', bound=Dict[str, Any]) ++ ++ @overload ++ def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ++ ... ++ ++ @overload ++ def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ++ ... ++ ++ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): ++ r"""Return a dictionary containing references to the whole state of the module. ++ ++ Both parameters and persistent buffers (e.g. running averages) are ++ included. Keys are corresponding parameter and buffer names. ++ Parameters and buffers set to ``None`` are not included. ++ ++ .. note:: ++ The returned object is a shallow copy. It contains references ++ to the module's parameters and buffers. ++ ++ .. warning:: ++ Currently ``state_dict()`` also accepts positional arguments for ++ ``destination``, ``prefix`` and ``keep_vars`` in order. However, ++ this is being deprecated and keyword arguments will be enforced in ++ future releases. ++ ++ .. warning:: ++ Please avoid the use of argument ``destination`` as it is not ++ designed for end-users. ++ ++ Args: ++ destination (dict, optional): If provided, the state of module will ++ be updated into the dict and the same object is returned. ++ Otherwise, an ``OrderedDict`` will be created and returned. ++ Default: ``None``. ++ prefix (str, optional): a prefix added to parameter and buffer ++ names to compose the keys in state_dict. Default: ``''``. ++ keep_vars (bool, optional): by default the :class:`~mindtorch.Tensor` s ++ returned in the state dict are detached from autograd. If it's ++ set to ``True``, detaching will not be performed. ++ Default: ``False``. ++ ++ Returns: ++ dict: ++ a dictionary containing a whole state of the module ++ ++ Example:: ++ ++ >>> # xdoctest: +SKIP("undefined vars") ++ >>> module.state_dict().keys() ++ ['bias', 'weight'] ++ ++ """ ++ # TODO: Remove `args` and the parsing logic when BC allows. 
++ if len(args) > 0: ++ if destination is None: ++ destination = args[0] ++ if len(args) > 1 and prefix == '': ++ prefix = args[1] ++ if len(args) > 2 and keep_vars is False: ++ keep_vars = args[2] ++ # DeprecationWarning is ignored by default ++ warnings.warn( ++ "Positional args are being deprecated, use kwargs instead.") ++ ++ if destination is None: ++ destination = OrderedDict() ++ destination._metadata = OrderedDict() ++ ++ local_metadata = {} ++ if hasattr(destination, "_metadata"): ++ destination._metadata[prefix[:-1]] = local_metadata ++ ++ for hook in self._state_dict_pre_hooks.values(): ++ hook(self, prefix, keep_vars) ++ self._save_to_state_dict(destination, prefix, keep_vars) ++ for name, module in self._modules.items(): ++ if module is not None: ++ module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) ++ for hook in self._state_dict_hooks.values(): ++ hook_result = hook(self, destination, prefix, local_metadata) ++ if hook_result is not None: ++ destination = hook_result ++ return destination ++ ++ def _register_load_state_dict_pre_hook(self, hook, with_module=False): ++ r"""Register a pre-hook for the :meth:`~nn.Module.load_state_dict` method. ++ ++ These hooks will be called with arguments: `state_dict`, `prefix`, ++ `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, ++ `error_msgs`, before loading `state_dict` into `self`. These arguments ++ are exactly the same as those of `_load_from_state_dict`. ++ ++ If ``with_module`` is ``True``, then the first argument to the hook is ++ an instance of the module. ++ ++ Arguments: ++ hook (Callable): Callable hook that will be invoked before ++ loading the state dict. ++ with_module (bool, optional): Whether or not to pass the module ++ instance to the hook as the first parameter. ++ """ ++ handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks) ++ self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(hook, self if with_module else None) ++ return handle ++ ++ def register_load_state_dict_post_hook(self, hook): ++ r"""Register a post hook to be run after module's ``load_state_dict`` is called. ++ ++ It should have the following signature:: ++ hook(module, incompatible_keys) -> None ++ ++ The ``module`` argument is the current module that this hook is registered ++ on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting ++ of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` ++ is a ``list`` of ``str`` containing the missing keys and ++ ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. ++ ++ The given incompatible_keys can be modified inplace if needed. ++ ++ Note that the checks performed when calling :func:`load_state_dict` with ++ ``strict=True`` are affected by modifications the hook makes to ++ ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either ++ set of keys will result in an error being thrown when ``strict=True``, and ++ clearing out both missing and unexpected keys will avoid an error. 
++ ++ Returns: ++ :class:`utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = hooks.RemovableHandle(self._load_state_dict_post_hooks) ++ self._load_state_dict_post_hooks[handle.id] = hook ++ return handle ++ ++ def parameters_dict(self, recurse=True): ++ param_dict = OrderedDict() ++ for name, param in self.named_parameters(recurse=recurse, remove_duplicate=False): ++ param_dict[name] = param ++ return param_dict ++ ++ def register_forward_pre_hook( ++ self, ++ hook: Union[ ++ Callable[[T, Tuple[Any, ...]], Optional[Any]], ++ Callable[[T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]], ++ ], ++ *, ++ prepend: bool = False, ++ with_kwargs: bool = False, ++ ) -> RemovableHandle: ++ r"""Registers a forward pre-hook on the module. ++ ++ The hook will be called every time before :func:`forward` is invoked. ++ ++ ++ If ``with_kwargs`` is false or not specified, the input contains only ++ the positional arguments given to the module. Keyword arguments won't be ++ passed to the hooks and only to the ``forward``. The hook can modify the ++ input. User can either return a tuple or a single modified value in the ++ hook. We will wrap the value into a tuple if a single value is returned ++ (unless that value is already a tuple). The hook should have the ++ following signature:: ++ ++ hook(module, args) -> None or modified input ++ ++ If ``with_kwargs`` is true, the forward pre-hook will be passed the ++ kwargs given to the forward function. And if the hook modifies the ++ input, both the args and kwargs should be returned. The hook should have ++ the following signature:: ++ ++ hook(module, args, kwargs) -> None or a tuple of modified input and kwargs ++ ++ Args: ++ hook (Callable): The user defined hook to be registered. ++ prepend (bool): If true, the provided ``hook`` will be fired before ++ all existing ``forward_pre`` hooks on this ++ :class:`nn.modules.Module`. Otherwise, the provided ++ ``hook`` will be fired after all existing ``forward_pre`` hooks ++ on this :class:`nn.modules.Module`. Note that global ++ ``forward_pre`` hooks registered with ++ :func:`register_module_forward_pre_hook` will fire before all ++ hooks registered by this method. ++ Default: ``False`` ++ with_kwargs (bool): If true, the ``hook`` will be passed the kwargs ++ given to the forward function. ++ Default: ``False`` ++ ++ Returns: ++ :class:`utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = hooks.RemovableHandle( ++ self._forward_pre_hooks, ++ extra_dict=self._forward_pre_hooks_with_kwargs ++ ) ++ self._forward_pre_hooks[handle.id] = hook ++ if with_kwargs: ++ self._forward_pre_hooks_with_kwargs[handle.id] = True ++ ++ if prepend: ++ self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] ++ return handle ++ ++ ++ def register_forward_hook( ++ self, ++ hook: Union[ ++ Callable[[T, Tuple[Any, ...], Any], Optional[Any]], ++ Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], ++ ], ++ *, ++ prepend: bool = False, ++ with_kwargs: bool = False, ++ ) -> RemovableHandle: ++ r"""Registers a forward hook on the module. ++ ++ The hook will be called every time after :func:`forward` has computed an output. ++ ++ If ``with_kwargs`` is ``False`` or not specified, the input contains only ++ the positional arguments given to the module. 
Keyword arguments won't be ++ passed to the hooks and only to the ``forward``. The hook can modify the ++ output. It can modify the input inplace but it will not have effect on ++ forward since this is called after :func:`forward` is called. The hook ++ should have the following signature:: ++ ++ hook(module, args, output) -> None or modified output ++ ++ If ``with_kwargs`` is ``True``, the forward hook will be passed the ++ ``kwargs`` given to the forward function and be expected to return the ++ output possibly modified. The hook should have the following signature:: ++ ++ hook(module, args, kwargs, output) -> None or modified output ++ ++ Args: ++ hook (Callable): The user defined hook to be registered. ++ prepend (bool): If ``True``, the provided ``hook`` will be fired ++ before all existing ``forward`` hooks on this ++ :class:`nn.modules.Module`. Otherwise, the provided ++ ``hook`` will be fired after all existing ``forward`` hooks on ++ this :class:`nn.modules.Module`. Note that global ++ ``forward`` hooks registered with ++ :func:`register_module_forward_hook` will fire before all hooks ++ registered by this method. ++ Default: ``False`` ++ with_kwargs (bool): If ``True``, the ``hook`` will be passed the ++ kwargs given to the forward function. ++ Default: ``False`` ++ ++ Returns: ++ :class:`utils.hooks.RemovableHandle`: ++ a handle that can be used to remove the added hook by calling ++ ``handle.remove()`` ++ """ ++ handle = hooks.RemovableHandle( ++ self._forward_hooks, ++ extra_dict=self._forward_hooks_with_kwargs ++ ) ++ self._forward_hooks[handle.id] = hook ++ if with_kwargs: ++ self._forward_hooks_with_kwargs[handle.id] = True ++ ++ if prepend: ++ self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] ++ return handle ++ ++ def zero_grad(self, set_to_none: bool = True) -> None: ++ r"""Reset gradients of all model parameters. ++ ++ See similar function under :class:`mindtorch.optim.Optimizer` for more context. ++ ++ Args: ++ set_to_none (bool): instead of setting to zero, set the grads to None. ++ See :meth:`mindtorch.optim.Optimizer.zero_grad` for details. ++ """ ++ if getattr(self, "_is_replica", False): ++ warnings.warn( ++ "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " ++ "The parameters are copied (in a differentiable manner) from the original module. " ++ "This means they are not leaf nodes in autograd and so don't accumulate gradients. " ++ "If you need gradients in your forward method, consider using autograd.grad instead." 
++ ) ++ ++ for p in self.parameters(): ++ if p.grad is not None: ++ p.grad = None
diff --git a/open_r1/readme.md b/open_r1/readme.md
index e69de29bb..20c4a7fbf 100644
--- a/open_r1/readme.md
+++ b/open_r1/readme.md
@@ -0,0 +1,310 @@
+## Open-R1: a full reproduction based on MindNLP
+
+### Purpose of this repository
+- This repository reproduces the DeepSeek-R1 / Open-R1 training and inference pipeline end to end on MindSpore + MindNLP.
+- The goal is to provide directly runnable reproduction scripts while staying as close as possible to the Hugging Face Transformers / TRL interfaces and training flow.
+
+### Quick start
+- Launch supervised fine-tuning (SFT):
+
+```bash
+bash sh/sft.sh
+```
+
+- Notes:
+  - The script calls `src/mind_openr1/sft.py` in this repository and loads its configuration (see `src/mind_openr1/configs.py`).
+  - Training logs and weights are written to `trainer_output/` and `logs/` by default (both can be changed in the script or the config).
+
+### Directory layout (excerpt)
+- `sh/`: run scripts (SFT: `sft.sh`; add new scripts here as needed).
+- `src/mind_openr1/`: core source code and configuration:
+  - `sft.py`: SFT entry point.
+  - `configs.py`: training, model, and data configuration.
+  - `rewards.py`, `grpo.py`: reinforcement-learning-related modules (a starting point for extensions).
+  - `utils/`: helpers for data, evaluation, callbacks, logging, etc.
+- `data/` (if present): data directory (prepare as needed).
+- `logs/`, `trainer_output/`: training logs and output directories.
+
+### Environment and dependencies
+- Recommended versions: MindSpore (GPU or Ascend; see the official installation guide), MindNLP, Python 3.9+.
+- Pick the MindSpore build that matches your local hardware and drivers; install MindNLP following its official documentation.
+- Additional Python dependencies can be installed locally as needed; add a `requirements.txt` to this repository to pin versions.
+
+### Data preparation
+- Prepare a dataset that meets the SFT requirements and fill in its path and format in `configs.py` or via the script arguments.
+- For custom data loading/preprocessing, extend `src/mind_openr1/utils/data.py` or plug a custom Dataset into `sft.py`.
+
+### Training configuration and outputs
+- Training hyperparameters (batch size, learning rate, number of steps, save/eval intervals, etc.) can be changed in `configs.py`.
+- Logs are written to `logs/` during training; weights/checkpoints are saved to `trainer_output/` (or a path you configure).
+
+### Compatibility with TRL
+To let MindNLP align with and stay compatible with TRL and parts of the Transformers trainer behavior during training, this repository makes the necessary adjustments to several low-level components of `mindnlp/mindtorch` (see the appendix "Source changes made for TRL compatibility" at the end of this document):
+- Adjust the autograd interface to support scenarios such as differentiating only with respect to tensor inputs and backfilling parameter gradients manually.
+- Complete the module hook and `state_dict` behavior so that higher-level trainers/accelerators can cooperate with `Module`.
+- Provide a minimal `autograd.graph` API to keep existing call paths working.
+- Provide a robust fallback when a CPU backend primitive is missing.
+- Handle gradient accumulation and distributed training safely in the `Trainer`'s `training_step` to match common trainer behavior.
+
+To upstream these changes, submit them as PRs to the corresponding upstream projects and add test cases following the community guidelines.
+
+---
+
+## Appendix: source changes made for TRL compatibility
+
+The following summarizes the local modifications relative to `origin/master` and documents the changes made for TRL compatibility.
+
+### Overview
+- Baseline: branch `master` tracks `origin/master` with no extra commits; all changes are uncommitted local modifications.
+- Modified files and size (insertions/deletions):
+  - `mindnlp/transformers/trainer.py`: +46 / -4
+  - `mindtorch/_apis/cpu.py`: +7 / -1
+  - `mindtorch/autograd/__init__.py`: +1 / -0
+  - `mindtorch/autograd/function.py`: +154 / -126
+  - `mindtorch/nn/modules/module.py`: +2393 / -2373
+- New file:
+  - `mindtorch/autograd/graph.py`
+
+---
+
+### Detailed changes and locations
+
+#### 1) mindnlp/transformers/trainer.py
+- Key changes:
+  - Introduce a `_mindspore_grad_enabled` switch.
+  - Instead of calling `value_and_grad` directly on `forward_fn`, flatten the tensor-valued keys of `inputs` so that only tensor arguments participate in differentiation, avoiding a `dict` as a grad input.
+  - Obtain the gradients with `attach_grads=False` and backfill them into `param.grad` manually, compatible with gradient accumulation and distributed training (a condensed sketch of this pattern follows the hunk below).
+  - Fall back to the original `compute_loss` flow when the custom differentiation path is not taken.
+
+- Location (hunk):
+
+```diff
+@@ -88,9 +88,51 @@ def training_step(
+ 
+             return loss, loss_true
+ 
+-        if not hasattr(self, 'grad_fn'):
+-            self.grad_fn = autograd.value_and_grad(forward_fn, model.trainable_params(), has_aux=True)
++        if not hasattr(self, '_mindspore_grad_enabled'):
++            self._mindspore_grad_enabled = True
++
++        if self._mindspore_grad_enabled:
++            # Only pass tensor arguments, so that a dict is never used as a grad input
++            input_keys = tuple(sorted(k for k, v in inputs.items() if hasattr(v, "shape")))
++
++            def forward_fn_flat(*flat_tensors):
++                local_inputs = {}
++                # Rebuild inputs: tensor keys come from flat_tensors, everything else keeps its original value
++                for k in inputs:
++                    if k in input_keys:
++                        # map back by position
++                        idx = input_keys.index(k)
++                        local_inputs[k] = flat_tensors[idx]
++                    else:
++                        local_inputs[k] = inputs[k]
++
++                with self.compute_loss_context_manager():
++                    loss = self.compute_loss(model, local_inputs, num_items_in_batch=num_items_in_batch)
++
++                if self.args.n_gpu > 1:
++                    loss = loss.mean()
+ 
+-        loss_scaled, (loss_true,) = self.grad_fn(inputs, num_items_in_batch)
++                if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None:
++                    loss = loss / self.current_gradient_accumulation_steps
++
++                if self.accelerator.distributed_type != DistributedType.DEEPSPEED:
++                    loss = loss / self.accelerator.gradient_accumulation_steps
++
++                return loss, loss
++
++            weights = model.trainable_params()
++            flat_args = tuple(inputs[k] for k in input_keys)
++            grad_fn = autograd.value_and_grad(forward_fn_flat, weights, has_aux=True, attach_grads=False)
++            (loss_scaled, loss_true), grads = grad_fn(*flat_args)
++
++            # Backfill the gradients for the optimizer
++            for param, grad in zip(weights, grads):
++                if getattr(param, 'grad', None) is None:
++                    param.grad = mindtorch.tensor(grad, device=param.device)
++                else:
++                    param.grad += mindtorch.tensor(grad, device=param.device)
++
++            return loss_true
+ 
+-        return loss_true
++        loss_scaled = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
++        return loss_scaled
+```
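+
+The hunk above bundles the pattern with trainer bookkeeping. Below is a condensed, illustrative sketch of the same idea (not the exact trainer code): tensor-valued inputs are flattened into positional arguments, gradients are computed with `attach_grads=False`, and the results are written back into `param.grad` for the optimizer. The helper name `backward_step` and the simplified `compute_loss(model, inputs)` signature are made up for the example, and the gradient-accumulation/distributed scaling from the real hunk is omitted.
+
+```python
+import mindtorch
+from mindtorch import autograd
+
+
+def backward_step(model, inputs, compute_loss):
+    # Differentiate only w.r.t. tensor inputs; non-tensor entries stay fixed.
+    input_keys = tuple(sorted(k for k, v in inputs.items() if hasattr(v, "shape")))
+
+    def forward_fn_flat(*flat_tensors):
+        local_inputs = dict(inputs)
+        for idx, k in enumerate(input_keys):
+            local_inputs[k] = flat_tensors[idx]
+        loss = compute_loss(model, local_inputs)
+        return loss, loss  # (value to differentiate, auxiliary value)
+
+    weights = model.trainable_params()
+    grad_fn = autograd.value_and_grad(forward_fn_flat, weights, has_aux=True, attach_grads=False)
+    (loss_scaled, loss_true), grads = grad_fn(*(inputs[k] for k in input_keys))
+
+    # Backfill the gradients so a standard optimizer step can consume param.grad.
+    for param, grad in zip(weights, grads):
+        grad = mindtorch.tensor(grad, device=param.device)
+        param.grad = grad if param.grad is None else param.grad + grad
+    return loss_true
+```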
+
+---
+
+#### 2) mindtorch/_apis/cpu.py
+- Key changes:
+  - `empty` gains a fallback for the case where the CPU backend has not implemented the primitive, degrading gracefully to `numpy.empty` wrapped in a `mindtorch.Tensor`.
+
+- Location (hunk):
+
+```diff
+@@ -9,7 +9,13 @@ from .._op_prim.cpu import legacy
+ 
+ empty_op = Empty().set_device('CPU')
+ def empty(size, dtype):
+-    return empty_op(size, dtype=dtype, device='CPU')
++    try:
++        return empty_op(size, dtype=dtype, device='CPU')
++    except RuntimeError as err:  # pragma: no cover - fallback path depends on runtime
++        if 'Not implement' not in str(err):
++            raise
++        # MindSpore's default CPU backend does not implement the Empty primitive; fall back to numpy
++        return mindtorch.Tensor.from_numpy(np.empty(size, mindtorch.dtype2np[dtype]))
+```
+
+---
+
+#### 3) mindtorch/autograd/__init__.py
+- Key changes:
+  - Export `saved_tensors_hooks` and `current_hooks` to align with the PyTorch interface.
+
+- Location (hunk):
+
+```diff
+@@ -2,3 +2,4 @@
+ from .node import Node
+ from .function import Function, value_and_grad
+ from .grad_mode import no_grad, enable_grad, inference_mode
++from .graph import saved_tensors_hooks, current_hooks
+```
+
+---
+
+#### 4) mindtorch/autograd/function.py
+- Key changes:
+  - Rewrite `value_and_grad` (a usage sketch of both paths follows the hunk below):
+    - With `attach_grads=True`, use MindSpore's `value_and_grad` and merge the result safely into `param.grad`.
+    - With `attach_grads=False`, explicitly build and run the PyNative grad graph and return `(values, grads)`.
+  - Cache and reset the attached parameter set to avoid gradient contamination across calls.
+
+- Location (hunk, excerpt):
+
+```diff
+@@ -1,126 +1,154 @@
+-"""functional autograd"""
+-...
+-def value_and_grad(fn, params_or_argnums, has_aux=False, attach_grads=True):
+-    grad_fn = mindspore.value_and_grad(fn, None, tuple(params_or_argnums), has_aux)
+-    if attach_grads:
+-        def new_grad_fn(*args, **kwargs):
+-            values, grads = grad_fn(*args, **kwargs)
+-            for param, grad in zip(params_or_argnums, grads):
+-                grad = mindtorch.tensor(grad, device=param.device)
+-                if param.grad is None:
+-                    param.grad = grad
+-                else:
+-                    param.grad += grad
+-            return values
+-        return new_grad_fn
+-    return grad_fn
++"""functional autograd"""
++...
++def value_and_grad(fn, params_or_argnums, has_aux=False, attach_grads=True):
++    params = tuple(params_or_argnums)
++    # Fast path: let MindSpore wrap gradients when we want autoupdate of .grad
++    if attach_grads:
++        grad_fn = mindspore.value_and_grad(fn, None, params, has_aux)
++        def new_grad_fn(*args, **kwargs):
++            attached_params = getattr(new_grad_fn, 'attached_params', None)
++            if attached_params is not params:
++                if attached_params is not None:
++                    for param in attached_params:
++                        if param.grad is not None:
++                            param.grad = mindtorch.zeros_like(param.grad.detach())
++                new_grad_fn.attached_params = params
++            values, grads = grad_fn(*args, **kwargs)
++            for param, grad in zip(params, grads):
++                grad = mindtorch.tensor(grad, device=param.device)
++                if param.grad is None:
++                    param.grad = grad
++                else:
++                    updated_grad = mindtorch.zeros_like(param.grad, device=param.device)
++                    updated_grad.copy_(param.grad)
++                    updated_grad += grad
++                    param.grad = updated_grad
++            return values
++        return new_grad_fn
++
++    # Stable path for MindSpore PyNative: explicitly build and run grad graph
++    def value_and_grad_f(*args, **kwargs):
++        fn_ = fn
++        _pynative_executor.set_grad_flag(True)
++        _pynative_executor.new_graph(fn_, *args, **kwargs)
++        values = fn_(*args, **kwargs)
++        _pynative_executor.end_graph(fn_, values, *args, **kwargs)
++
++        run_args = args
++        if kwargs:
++            run_args = args + tuple(kwargs.values())
++
++        grads = _pynative_executor.grad(fn_, grad_, params, None, *run_args)
++        return values, grads
+```
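+
+To make the two paths concrete, here is a small usage sketch. It is illustrative only: the `nn.Linear`/`randn` setup is assumed to be available in `mindtorch` with the usual torch-like behavior and is not part of the patch.
+
+```python
+import mindtorch
+from mindtorch import nn
+from mindtorch.autograd import value_and_grad
+
+model = nn.Linear(4, 2)
+x = mindtorch.randn(3, 4)
+params = tuple(model.parameters())
+
+def loss_fn(inputs):
+    return model(inputs).sum()
+
+# Path 1: attach_grads=True (default). The returned callable runs the forward
+# pass and accumulates the gradients directly into each param.grad.
+step = value_and_grad(loss_fn, params)
+loss = step(x)
+
+# Path 2: attach_grads=False. The gradients are returned to the caller instead,
+# which is what the patched Trainer.training_step relies on.
+grad_fn = value_and_grad(loss_fn, params, attach_grads=False)
+loss, grads = grad_fn(x)
+for p, g in zip(params, grads):
+    p.grad = mindtorch.tensor(g, device=p.device)
+```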
+
+---
+
+#### 5) mindtorch/nn/modules/module.py
+- Key changes:
+  - Substantially align and extend the `Module` API and the global/instance-level hook machinery (buffer/parameter/module registration hooks, forward/forward-pre hooks, backward hooks, etc.).
+  - Add numerous internal helpers and strengthen the `state_dict`/`load_state_dict` behavior.
+  - The change spans the whole file (see the hunk header for the line range).
+
+- Location (hunk overview):
+
+```diff
+@@ -1,2373 +1,2393 @@
+```
+
+- The full diff is saved at `final/mindnlp/open_r1/module.diff`.
+
+---
+
+#### 6) New file: mindtorch/autograd/graph.py
+- Purpose:
+  - Provide a minimal `saved_tensors_hooks`/`current_hooks` API (backed by a thread-local stack) that matches the `torch.autograd.graph` interface. The hooks are not yet wired into MindSpore's gradient recording pipeline, but existing call sites can enter and exit the context manager safely (a usage sketch follows the excerpt below).
+
+- File excerpt (lines 1-24 of `mindtorch/autograd/graph.py`):
+
+```python
+"""Autograd graph utilities.
+
+This module provides a minimal ``saved_tensors_hooks`` implementation so that
+``mindtorch.autograd.graph`` exposes the same API surface as
+``torch.autograd.graph``. The current implementation focuses on API
+compatibility and keeps a thread-local stack of the registered hooks. The
+stored hooks are not yet wired into MindSpore's gradient recording pipeline,
+but existing code can safely enter/exit the context manager without raising
+``ImportError``.
+"""
+
+from __future__ import annotations
+
+from contextlib import ContextDecorator
+from typing import Any, Callable, List, Optional, Tuple
+import threading
+
+PackHook = Callable[[Any], Any]
+UnpackHook = Callable[[Any], Any]
+```
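+
+For illustration, the context manager is used the same way as its PyTorch counterpart. This is a minimal sketch under two assumptions: that `saved_tensors_hooks` takes a `(pack_hook, unpack_hook)` pair as in `torch.autograd.graph`, and that `current_hooks()` exposes whatever currently sits on the thread-local stack. Because the hooks are not yet wired into gradient recording, the snippet only exercises the bookkeeping.
+
+```python
+from mindtorch.autograd import saved_tensors_hooks, current_hooks
+
+def pack(tensor):
+    # Called for each tensor saved for backward; a real hook could offload or compress it.
+    return tensor
+
+def unpack(packed):
+    # Inverse of pack; must return a tensor equivalent to the saved one.
+    return packed
+
+# Entering the context pushes the (pack, unpack) pair onto the thread-local stack;
+# exiting pops it again. Gradients are unaffected until the hooks are wired in.
+with saved_tensors_hooks(pack, unpack):
+    print(current_hooks())
+```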
+""" + +from __future__ import annotations + +from contextlib import ContextDecorator +from typing import Any, Callable, List, Optional, Tuple +import threading + +PackHook = Callable[[Any], Any] +UnpackHook = Callable[[Any], Any] +``` + +--- + +### 复现实验/对比方法 +- 列出本地相对上游差异(文件与统计): + - `git diff --stat origin/master...master` +- 查看具体文件差异: + - `git diff -- mindnlp/transformers/trainer.py` + - `git diff -- mindtorch/_apis/cpu.py` + - `git diff -- mindtorch/autograd/__init__.py` + - `git diff -- mindtorch/autograd/function.py` + - `git diff -- mindtorch/nn/modules/module.py` +- 导出 `module.py` 全量 diff(已导出): + - `git diff --no-color -- mindtorch/nn/modules/module.py > final/mindnlp/open_r1/module.diff` From 5f3ef1f6483f0e998fdeb7f5ca5485f2c03635fe Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Wed, 8 Oct 2025 14:17:18 +0800 Subject: [PATCH 6/7] add:readme --- open_r1/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_r1/readme.md b/open_r1/readme.md index 20c4a7fbf..006cb288f 100644 --- a/open_r1/readme.md +++ b/open_r1/readme.md @@ -20,7 +20,7 @@ bash sh/sft.sh - `src/mind_openr1/`:核心源码与配置: - `sft.py`:监督微调入口。 - `configs.py`:训练与模型/数据相关配置。 - - `rewards.py`、`grpo.py`:与强化学习相关模块(如需扩展可参考此处)。 + - `rewards.py`、`grpo.py`:与强化学习相关模块。 - `utils/`:数据、评估、回调、日志等辅助模块。 - `data/`(如存在):数据相关目录(按需准备)。 - `logs/`、`trainer_output/`:训练日志与输出目录。 From 134365df714117673dbbcba589e668df1952d7ba Mon Sep 17 00:00:00 2001 From: guojialiang <2802427218@qq.com> Date: Wed, 8 Oct 2025 14:18:59 +0800 Subject: [PATCH 7/7] add:readme --- open_r1/readme.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/open_r1/readme.md b/open_r1/readme.md index 006cb288f..9a104a972 100644 --- a/open_r1/readme.md +++ b/open_r1/readme.md @@ -26,9 +26,8 @@ bash sh/sft.sh - `logs/`、`trainer_output/`:训练日志与输出目录。 ### 环境与依赖 -- 建议版本:MindSpore(GPU/Ascend 任一;参考官方安装指南)、MindNLP、Python 3.9+。 -- 请根据本地硬件与驱动选择合适的 MindSpore 发行版;MindNLP 请参考官方文档安装。 -- 额外 Python 依赖可按需在本地安装;如需固定版本,可在本仓库补充 `requirements.txt`。 +- 建议版本:MindSpore 2.6、MindNLP 0.5.0rc2、Python 3.10+。 +- MindNLP 请参考官方文档安装。 ### 数据准备 - 请准备符合监督微调(SFT)需求的数据集,并在 `configs.py` 或脚本参数中填入数据路径与格式。 @@ -46,7 +45,6 @@ bash sh/sft.sh - 在 CPU 后端缺失场景提供稳健回退。 - 在 `Trainer` 的 `training_step` 中对梯度累积/分布式做安全处理以对齐常见训练器行为。 -如需将这些改动上游化,请以 PR 形式提交至对应上游项目并依据社区规范完善测试用例。 ---