Commit 195f61b

[feat] add fsdp2 to fsdp_sft_trainer (#1713)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

Add fsdp2 to fsdp_sft_trainer. Resolves issue #1504.

### High-Level Design

Refer to the implementation of #1026.

### Usage Example

```python
model.strategy=fsdp2
```

### Test

<img width="1095" alt="image" src="https://github.com/user-attachments/assets/1f70db1c-9ac3-448e-abca-fd302480f0c7" />

### Additional Info

- **Issue Number**: #1504
- **Training**: FSDP backend is affected.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.
1 parent: 7853292 · commit: 195f61b

File tree

2 files changed: +87, -27 lines

- verl/trainer/config/sft_trainer.yaml
- verl/trainer/fsdp_sft_trainer.py

verl/trainer/config/sft_trainer.yaml

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ model:
   lora_alpha: 16 # LoRA scaling factor
   target_modules: all-linear # Target modules for LoRA adaptation
   use_liger: False
+  strategy: fsdp2
 optim:
   lr: 1e-5
   betas: [0.9, 0.95]
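
For context, a minimal sketch of how the new `model.strategy` key can be read and validated from the YAML. OmegaConf is assumed here (consistent with verl's Hydra-style configs); the validation itself is illustrative and not part of this commit.

```python
# Illustrative only: load the SFT config and validate the new key.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"strategy": "fsdp2"}})  # stands in for sft_trainer.yaml
assert cfg.model.strategy in ("fsdp", "fsdp2"), f"unsupported strategy: {cfg.model.strategy}"
print(cfg.model.strategy)  # "fsdp2" selects the fully_shard (FSDP2) code path below
```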

verl/trainer/fsdp_sft_trainer.py

Lines changed: 86 additions & 27 deletions
@@ -46,7 +46,16 @@
 from verl.utils.debug import log_gpu_memory_usage
 from verl.utils.distributed import initialize_global_process_group
 from verl.utils.fs import copy_to_local
-from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn
+from verl.utils.fsdp_utils import (
+    CPUOffloadPolicy,
+    MixedPrecisionPolicy,
+    apply_fsdp2,
+    fsdp2_load_full_state_dict,
+    get_fsdp_wrap_policy,
+    get_init_weight_context_manager,
+    init_fn,
+    fsdp2_clip_grad_norm_
+)
 from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
 from verl.utils.py_functional import convert_to_regular_types
 from verl.utils.tracking import Tracking
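
The new symbols (`CPUOffloadPolicy`, `MixedPrecisionPolicy`, `apply_fsdp2`, `fsdp2_load_full_state_dict`, `fsdp2_clip_grad_norm_`) come from verl's `fsdp_utils` wrapper. A plausible sketch of the version guard that makes the later `assert CPUOffloadPolicy is not None` check meaningful, assuming PyTorch >= 2.6 exports the FSDP2 policies from `torch.distributed.fsdp` (the actual verl helper may use a different import path or fallbacks):

```python
# Hypothetical version guard: on older PyTorch the FSDP2 symbols resolve to None,
# so callers can assert on availability before taking the fsdp2 branch.
try:
    from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
except ImportError:  # PyTorch builds without the composable FSDP2 API
    CPUOffloadPolicy = MixedPrecisionPolicy = fully_shard = None
```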
@@ -173,6 +182,7 @@ def _build_model_optimizer(self):
         trust_remote_code = self.config.model.trust_remote_code
         # load config first
         config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
+        self.model_config = config
         if self.config.ulysses_sequence_parallel_size > 1:
             assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"
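
`self.model_config` is stashed here because the FSDP2 branch of `save_checkpoint` below writes `config.json` explicitly via `save_pretrained`. A minimal illustration of that pattern; the model name and output path are placeholders, not from this commit.

```python
# Illustrative: persisting the HF config alongside saved weights.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
config.save_pretrained("/tmp/sft_ckpt/global_step_1")  # writes config.json for later reloads
```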
@@ -231,18 +241,38 @@ def _build_model_optimizer(self):
         else:
             cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)

-        self.fsdp_model = FSDP(
-            module=self.model,
-            auto_wrap_policy=auto_wrap_policy,
-            param_init_fn=init_fn,
-            sharding_strategy=ShardingStrategy.FULL_SHARD,
-            mixed_precision=mixed_precision,
-            device_mesh=self.device_mesh,
-            sync_module_states=True,
-            device_id=get_torch_device().current_device(),
-            cpu_offload=cpu_offload,
-            use_orig_params=False,
-        )
+        fsdp_strategy = self.config.model.strategy
+        if fsdp_strategy == "fsdp":
+            self.fsdp_model = FSDP(
+                self.model,
+                cpu_offload=cpu_offload,
+                param_init_fn=init_fn,
+                use_orig_params=False,
+                auto_wrap_policy=auto_wrap_policy,
+                device_id=get_torch_device().current_device(),
+                sharding_strategy=ShardingStrategy.FULL_SHARD,
+                mixed_precision=mixed_precision,
+                sync_module_states=True,
+                device_mesh=self.device_mesh,
+                forward_prefetch=False,
+            )
+        elif fsdp_strategy == "fsdp2":
+            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+            mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32,
+                                             cast_forward_inputs=True)
+
+            fsdp_kwargs = {
+                "mesh": self.device_mesh,
+                "mp_policy": mp_policy,
+                "offload_policy": cpu_offload,
+                "reshard_after_forward": True,
+            }
+            full_state = self.model.state_dict()
+            apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config)
+            fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload)
+            self.fsdp_model = self.model
+        else:
+            raise NotImplementedError(f"not implement {fsdp_strategy}")

         log_gpu_memory_usage("After FSDP wrapping", logger=logger)

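`apply_fsdp2` is verl's helper around the composable `fully_shard` API. A rough sketch of the wrapping pattern such a helper is expected to follow, assuming PyTorch >= 2.6 and a name-based heuristic for transformer blocks (the real helper derives the wrap targets from `fsdp_config`):

```python
# Sketch: shard each decoder layer, then the root module, so FSDP2 groups
# parameters per block for all-gather/reshard. Mutates the model in place.
import torch
from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.6 export path

def apply_fsdp2_sketch(model: torch.nn.Module, fsdp_kwargs: dict) -> torch.nn.Module:
    for module in model.modules():
        if module.__class__.__name__.endswith("DecoderLayer"):  # heuristic wrap target
            fully_shard(module, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)  # root wrap picks up the remaining parameters
    return model
```

Unlike FSDP1, there is no wrapper class: after `fully_shard` the original module is the sharded model, which is why this branch assigns `self.fsdp_model = self.model`.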
@@ -373,7 +403,12 @@ def training_step(self, batch: TensorDict):
             loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
             step_loss += loss.item()

-        grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
+        if self.config.model.strategy == 'fsdp':
+            grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
+        elif self.config.model.strategy == 'fsdp2':
+            grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad)
+        else:
+            raise NotImplementedError(f"not implement {self.config.model.strategy}")

         log_gpu_memory_usage("Before optimizer step", logger=logger)

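FSDP2 parameters are DTensors, so the FSDP1 `FSDP.clip_grad_norm_` method is not available and a standalone helper is used instead. An illustrative sketch of such a helper, assuming the PyTorch >= 2.6 `get_total_norm`/`clip_grads_with_norm_` utilities; verl's actual `fsdp2_clip_grad_norm_` may differ in details.

```python
# Sketch of a DTensor-aware gradient clipping helper.
import torch
from torch.distributed.tensor import DTensor
from torch.nn.utils import clip_grads_with_norm_, get_total_norm

def clip_grad_norm_fsdp2(parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    params = [p for p in parameters if p.grad is not None]
    grads = [p.grad for p in params]
    total_norm = get_total_norm(grads, norm_type=norm_type)
    if isinstance(total_norm, DTensor):
        total_norm = total_norm.full_tensor()  # materialize the global norm on every rank
    clip_grads_with_norm_(params, max_norm, total_norm)
    return total_norm
```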
@@ -414,21 +449,44 @@ def validation_step(self, batch: TensorDict):

     def save_checkpoint(self, step):
         # save checkpoint
-        from torch.distributed.fsdp import FullStateDictConfig, StateDictType
+        path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")

-        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
-            state_dict = self.fsdp_model.state_dict()
+        fsdp_strategy = self.config.model.strategy
+        if fsdp_strategy == "fsdp":
+            # FSDP1 checkpoint saving
+            from torch.distributed.fsdp import FullStateDictConfig, StateDictType
+
+            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+            with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
+                state_dict = self.fsdp_model.state_dict()
+
+            # save huggingface model
+            if self.device_mesh.get_rank() == 0:
+                os.makedirs(path, exist_ok=True)
+                self.model.save_pretrained(path, state_dict=state_dict)
+                self.tokenizer.save_pretrained(path)
+        elif fsdp_strategy == "fsdp2":
+            # FSDP2 checkpoint saving
+            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+
+            # Get full state dict with FSDP2
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+            state_dict = get_model_state_dict(self.fsdp_model, options=options)
+
+            # save huggingface model
+            if self.device_mesh.get_rank() == 0:
+                os.makedirs(path, exist_ok=True)
+                self.model.save_pretrained(path, state_dict=state_dict)
+                self.model_config.save_pretrained(path)
+                self.tokenizer.save_pretrained(path)
+        else:
+            raise NotImplementedError(f"not implement {fsdp_strategy}")
+
+        # Copy to HDFS if configured
+        if self.device_mesh.get_rank() == 0 and self.config.trainer.default_hdfs_dir:
+            hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
+            hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)

-        path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")
-        # save huggingface model
-        if self.device_mesh.get_rank() == 0:
-            os.makedirs(path, exist_ok=True)
-            self.model.save_pretrained(path, state_dict=state_dict)
-            self.tokenizer.save_pretrained(path)
-        if self.config.trainer.default_hdfs_dir:
-            hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
-            hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
         torch.distributed.barrier()

     def fit(self):
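
The saving path uses `get_model_state_dict` to gather a full, CPU-offloaded state dict. The loading counterpart used at wrap time, `fsdp2_load_full_state_dict`, plausibly relies on the matching setter; a sketch under that assumption (the real verl helper may also handle buffers and offloading differently):

```python
# Sketch: scatter a rank-0 full state dict into the sharded (DTensor) parameters.
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

def load_full_state_dict_sketch(model, full_state_dict):
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(model, full_state_dict, options=options)
```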
@@ -462,6 +520,7 @@ def fit(self):
                 self.train_dataloader,
                 total=self.steps_per_epoch,
                 desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}",
+                disable=rank != 0
             ):
                 global_step += 1
                 data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name)
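
The added `disable=rank != 0` keeps the progress bar on rank 0 only. A minimal standalone illustration; `RANK` is the environment variable set by torchrun, read here only for the sake of the example.

```python
# Only rank 0 renders the bar; other ranks iterate silently.
import os
from tqdm import tqdm

rank = int(os.environ.get("RANK", "0"))
for _ in tqdm(range(10), desc="Epoch 1/1", disable=rank != 0):
    pass
```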
