212 changes: 210 additions & 2 deletions swift/megatron/argument/megatron_args.py
@@ -10,13 +10,15 @@
from transformers.utils.versions import require_version

from swift.llm.argument.base_args import to_abspath
from swift.utils import get_dist_setting, get_logger, json_parse_to_dict
from swift.utils import get_current_device, get_dist_setting, get_logger, is_master, json_parse_to_dict

logger = get_logger()


@dataclass
class RLHFMegatronArgumentsMixin:
perform_initialization: bool = True
rlhf_type: Literal['dpo', 'grpo'] = 'dpo'
ref_load: Optional[str] = None
ref_adapter_load: Optional[str] = None

@@ -27,6 +29,211 @@ class RLHFMegatronArgumentsMixin:
f_divergence_type: str = 'reverse_kl'
loss_type: str = 'sigmoid'

# =========================== GRPO ===========================
generation_batch_size: Optional[int] = None
steps_per_generation: Optional[int] = None
num_generations: int = 8
max_completion_length: int = 512

# ─────────────────────────── Sampling ───────────────────────────
epsilon: float = 0.2
epsilon_high: Optional[float] = None
delta: Optional[float] = None
top_k: int = 50
top_p: float = 0.9
repetition_penalty: float = 1.
# ─────────────────────────── VLLM ───────────────────────────
use_vllm: bool = False
vllm_mode: Literal['server', 'colocate'] = 'colocate'
# ────────────── Internal VLLM (colocate) ──────────────
vllm_enable_prefix_caching: bool = True
vllm_gpu_memory_utilization: float = 0.9
vllm_tensor_parallel_size: int = 1
vllm_max_model_len: Optional[int] = None
vllm_enforce_eager: bool = False
vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}'
vllm_disable_cascade_attn: bool = False
sleep_level: Literal[0, 1, 2] = 0

# ────────────── External VLLM (server) ──────────────
vllm_server_base_url: Optional[List[str]] = None
vllm_server_host: Optional[List[str]] = None
vllm_server_port: List[int] = field(default_factory=lambda: [8000])
vllm_server_timeout: float = 240.0
vllm_client: Optional[object] = field(init=False, default=None)

# ─────────────────────────── Reward ───────────────────────────
reward_funcs: List[str] = field(default_factory=list)
reward_weights: List[float] = None
# see details in swift/plugin/orm.py
# cosine reward, https://arxiv.org/abs/2502.03373
cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length.
cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length.
cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length.
cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length.
cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length
# repetition penalty, https://arxiv.org/abs/2502.03373
repetition_n_grams: int = 3
repetition_max_penalty: float = -1.0
# soft_overlong, https://arxiv.org/abs/2503.14476
soft_max_length: Optional[int] = None
soft_cache_length: Optional[int] = None

# ─────────────────────────── Not Supported Yet ───────────────────────────
# reward model
reward_model: Optional[List[str]] = None
reward_model_plugin: Optional[List[str]] = None
# sync ref model
sync_ref_model: bool = False
ref_model_sync_steps: int = 512
ref_model_mixup_alpha: float = 0.6

async_generate: bool = False

move_model_batches: Optional[int] = None
offload_optimizer: bool = False
offload_model: bool = False
gc_collect_after_offload: bool = False # deprecated

# multi turn
multi_turn_func: Optional[str] = None # deprecated
multi_turn_scheduler: Optional[str] = None
max_turns: Optional[int] = None
completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round'
vllm_server_pass_dataset: bool = False

# DAPO, https://arxiv.org/abs/2503.14476
dynamic_sample: bool = False
max_resample_times: int = 3
overlong_filter: bool = False

# Dr. GRPO, https://arxiv.org/abs/2503.20783
scale_rewards: bool = True

# entropy
log_entropy: bool = False
# Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939
top_entropy_quantile: float = 1.0

# GSPO https://www.arxiv.org/abs/2507.18071
importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token'

wandb_log_unique_prompts: Optional[bool] = None
num_iterations: int = 1

# dataset
dataset_shuffle: Optional[bool] = True

def __post_init__(self):
if self.rlhf_type == 'grpo':
self._init_grpo()
super().__post_init__()
if self.rlhf_type == 'grpo':
self._set_grpo_default()

def _set_grpo_default(self):
if self.beta is None:
self.beta = 0.04 # https://arxiv.org/abs/2402.03300

def _init_grpo(self):

def _init_external_vllm():
if self.rlhf_type != 'grpo' or (self.vllm_server_host is None and self.vllm_server_base_url is None):
return
from swift.trainers.rlhf_trainer.vllm_client import VLLMClient
if is_master():
logger.info('Start connecting to vLLM server')
self.vllm_client = VLLMClient(
base_urls=self.vllm_server_base_url,
hosts=self.vllm_server_host,
server_ports=self.vllm_server_port,
connection_timeout=self.vllm_server_timeout)
self.vllm_client.init_communicator(device=get_current_device())
logger.info('Connected to vLLM server')

def _check_not_supported():
# TODO: check
# bool
not_supported_args = [
'sync_ref_model',
'async_generate',
]
for arg in not_supported_args:
if getattr(self, arg):
raise ValueError(f'{arg} is not supported for Megatron-GRPO yet, please unset it.')
# else
if self.num_iterations > 1:
raise ValueError('num_iterations > 1 is not supported for Megatron-GRPO yet, please set it to 1.')

def _check_batch_params():
# assert self.micro_batch_size % self.num_generations == 0, \
# f'micro_batch_size ({self.micro_batch_size}) must be divisible' \
# f' by the number of generations ({self.num_generations})'
if self.generation_batch_size is None and self.steps_per_generation is None:
self.steps_per_generation = 1
self.generation_batch_size = self.global_batch_size * self.steps_per_generation
elif self.generation_batch_size is not None and self.steps_per_generation is None:
# Just ensure the value is divisible by the global batch size
if self.generation_batch_size % self.global_batch_size != 0:
raise ValueError(f'generation_batch_size ({self.generation_batch_size}) '
f'must be divisible by the global batch size ({self.global_batch_size}).')
self.steps_per_generation = self.generation_batch_size // self.global_batch_size
elif self.generation_batch_size is None and self.steps_per_generation is not None:
self.generation_batch_size = self.global_batch_size * self.steps_per_generation
else:
raise ValueError(
"'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time")
world_size = torch.distributed.get_world_size()
assert self.generation_batch_size % world_size == 0, \
f'generation_batch_size ({self.generation_batch_size}) ' \
f'must be divisible by the world size ({world_size})'
self.per_device_generation_batch_size = self.generation_batch_size // world_size

_init_external_vllm()
_check_not_supported()
_check_batch_params()
# map the DPO default loss_type ('sigmoid') to the GRPO default ('grpo')
if self.loss_type == 'sigmoid':
self.loss_type = 'grpo'
assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \
f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}'
if self.async_generate or not self.use_vllm:
self.sleep_level = 0
self.remove_unused_columns = False
logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}')
if self.truncation_strategy is None:
self.truncation_strategy = 'left'
assert self.truncation_strategy in ['left', 'delete'], (
"GRPO requires `truncation_strategy` to be 'left' or 'delete', "
f"but got truncation_strategy='{self.truncation_strategy}'.")
if self.beta is None:
self.beta = 0.04 # https://arxiv.org/abs/2402.03300
if self.async_generate:
logger.info('Using async mode. This is an approximate variant that generates responses '
'with the old weights to speed up training. It ignores the `CLIP` of advantages; '
'if training becomes unstable, consider setting --async_generate false.')
if 'soft_overlong' in self.reward_funcs:
assert self.soft_cache_length is not None, \
'The soft_cache_length must be set when using soft overlong rewards.'
if self.soft_max_length is None:
self.soft_max_length = self.max_completion_length
logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}')
if self.use_vllm:
# set vllm mode
if self.vllm_server_host is not None or self.vllm_server_base_url is not None:
if self.vllm_mode != 'server':
self.vllm_mode = 'server'
logger.warning('Setting vllm_mode to `server` because vllm_server_host/vllm_server_base_url is provided.')
else:
if self.vllm_mode != 'colocate':
self.vllm_mode = 'colocate'
logger.warning('Setting vllm_mode to `colocate` because no vllm_server_host/vllm_server_base_url is provided.')


@dataclass
class MegatronTunerMixin:
@@ -156,6 +363,7 @@ class MegatronArguments(ExtraMegatronArguments):
dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic'
manual_gc: bool = False
manual_gc_interval: int = 0
use_mbridge: bool = False

# learning rate
lr: Optional[float] = None
Expand Down Expand Up @@ -184,7 +392,7 @@ class MegatronArguments(ExtraMegatronArguments):
no_load_rng: bool = False
finetune: bool = False
ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist'
no_initialization: bool = True
no_initialization: bool = False
auto_detect_ckpt_format: bool = True
exit_on_missing_checkpoint: bool = True

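The `_check_batch_params` logic above ties three batch knobs together with simple arithmetic. A minimal sketch of the relationships follows; the concrete numbers are illustrative and not defaults taken from this PR.

```python
# Illustrative values only; the relationships mirror _check_batch_params above.
global_batch_size = 32       # assumed optimizer batch size across all ranks
world_size = 8               # assumed total number of ranks
steps_per_generation = 4     # one rollout round feeds this many optimizer steps

# If only steps_per_generation is set (or neither), the generation batch
# defaults to global_batch_size * steps_per_generation.
generation_batch_size = global_batch_size * steps_per_generation  # 128

# Both divisibility checks from the diff must hold.
assert generation_batch_size % global_batch_size == 0
assert generation_batch_size % world_size == 0

# Each rank then generates this many completions per rollout round.
per_device_generation_batch_size = generation_batch_size // world_size  # 16
```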
1 change: 0 additions & 1 deletion swift/megatron/argument/rlhf_args.py
@@ -7,7 +7,6 @@

@dataclass
class MegatronRLHFArguments(MegatronTrainArguments):
rlhf_type: Literal['dpo'] = 'dpo'
loss_scale: str = 'last_round'

calculate_per_token_loss: bool = False
3 changes: 2 additions & 1 deletion swift/megatron/argument/train_args.py
@@ -9,7 +9,7 @@
from swift.llm.argument.base_args import to_abspath
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
from ..model import get_megatron_model_meta
from .megatron_args import MegatronArguments
from .megatron_args import MegatronArguments, RLHFMegatronArgumentsMixin

logger = get_logger()

@@ -29,6 +29,7 @@ def init_model_args(self, tokenizer, config):
if getattr(self, k) is None:
setattr(self, k, v)
MegatronArguments.__post_init__(self)
RLHFMegatronArgumentsMixin.__post_init__(self)
self.extra_args = self.parse_to_megatron()
self.extra_args['model_info'] = self.model_info
self.extra_args['model_meta'] = self.model_meta
14 changes: 12 additions & 2 deletions swift/megatron/train/rlhf.py
@@ -1,9 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Optional, Union

from swift.trainers.rlhf_trainer.utils import identity_data_collator
from swift.utils import get_logger
from ..argument import MegatronRLHFArguments
from ..trainers import MegatronDPOTrainer
from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer
from .sft import MegatronSft

logger = get_logger()
@@ -17,13 +18,22 @@ def prepare_trainer(self):
args = self.args
if args.rlhf_type == 'dpo':
trainer_cls = MegatronDPOTrainer
elif args.rlhf_type == 'grpo':
trainer_cls = MegatronGRPOTrainer
else:
raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.')
return trainer_cls(args, self.template)

def _prepare_template(self) -> None:
super()._prepare_template()
self.template.set_mode('rlhf')
mode_mapping = {'grpo': 'train'}
self.template.set_mode(mode_mapping.get(self.args.rlhf_type, 'rlhf'))

def _get_data_collator(self):
if self.args.rlhf_type == 'grpo':
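# The parent collator is built here (presumably for its side effects) but not returned;
# GRPO gets an identity collator, so batch assembly is assumed to happen later in the trainer.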
super()._get_data_collator()
return identity_data_collator
return super()._get_data_collator()


def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None):
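For orientation, a hedged sketch of invoking the new GRPO path via `megatron_rlhf_main`. The flag names simply mirror the dataclass fields added in megatron_args.py and common ms-swift arguments; the model/dataset ids and values are illustrative, not taken from this PR.

```python
from swift.megatron.train.rlhf import megatron_rlhf_main

# Minimal sketch, assuming the usual ms-swift convention that dataclass fields
# map to CLI-style flags of the same name; all values below are illustrative.
megatron_rlhf_main([
    '--rlhf_type', 'grpo',
    '--model', 'Qwen/Qwen2.5-7B-Instruct',  # hypothetical model id
    '--dataset', 'AI-MO/NuminaMath-TIR',    # hypothetical dataset id
    '--reward_funcs', 'accuracy',
    '--num_generations', '8',
    '--max_completion_length', '512',
    '--use_vllm', 'true',
    '--vllm_mode', 'colocate',
    '--vllm_gpu_memory_utilization', '0.5',
])
```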
1 change: 1 addition & 0 deletions swift/megatron/trainers/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .dpo_trainer import MegatronDPOTrainer
from .grpo_trainer import MegatronGRPOTrainer
from .trainer import MegatronTrainer
4 changes: 2 additions & 2 deletions swift/megatron/trainers/base.py
@@ -27,7 +27,7 @@
from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory
from packaging import version

from swift.llm import dynamic_gradient_checkpointing
from swift.llm import Template, dynamic_gradient_checkpointing
from swift.plugin import MeanMetric
from swift.trainers import SwiftMixin
from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger
@@ -39,7 +39,7 @@

class BaseMegatronTrainer(ABC):

def __init__(self, args, template):
def __init__(self, args, template: Template):
self.args = args
self.template = template
self.stimer = StragglerDetector()