@@ -46,7 +46,16 @@
 from verl.utils.debug import log_gpu_memory_usage
 from verl.utils.distributed import initialize_global_process_group
 from verl.utils.fs import copy_to_local
-from verl.utils.fsdp_utils import get_fsdp_wrap_policy, get_init_weight_context_manager, init_fn
+from verl.utils.fsdp_utils import (
+    CPUOffloadPolicy,
+    MixedPrecisionPolicy,
+    apply_fsdp2,
+    fsdp2_load_full_state_dict,
+    get_fsdp_wrap_policy,
+    get_init_weight_context_manager,
+    init_fn,
+    fsdp2_clip_grad_norm_
+)
 from verl.utils.torch_functional import get_cosine_schedule_with_warmup, get_wsd_schedule_with_warmup
 from verl.utils.py_functional import convert_to_regular_types
 from verl.utils.tracking import Tracking
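The new imports pull the FSDP2 helpers from verl.utils.fsdp_utils. The later `assert CPUOffloadPolicy is not None` implies these names fall back to None on older PyTorch. A minimal sketch of such a version guard, assuming the FSDP2 symbols live in torch.distributed._composable.fsdp (PyTorch >= 2.4); this is not the actual verl code:

    # Hypothetical sketch: keep the module importable on PyTorch < 2.4 by falling
    # back to None when the fully_shard (FSDP2) API is unavailable.
    try:
        from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard
    except ImportError:
        CPUOffloadPolicy = MixedPrecisionPolicy = fully_shard = None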
@@ -173,6 +182,7 @@ def _build_model_optimizer(self):
         trust_remote_code = self.config.model.trust_remote_code
         # load config first
         config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code)
+        self.model_config = config
         if self.config.ulysses_sequence_parallel_size > 1:
             assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled"
 
@@ -231,18 +241,38 @@ def _build_model_optimizer(self):
         else:
             cpu_offload = CPUOffload(offload_params=self.config.model.fsdp_config.offload_params)
 
-        self.fsdp_model = FSDP(
-            module=self.model,
-            auto_wrap_policy=auto_wrap_policy,
-            param_init_fn=init_fn,
-            sharding_strategy=ShardingStrategy.FULL_SHARD,
-            mixed_precision=mixed_precision,
-            device_mesh=self.device_mesh,
-            sync_module_states=True,
-            device_id=get_torch_device().current_device(),
-            cpu_offload=cpu_offload,
-            use_orig_params=False,
-        )
+        fsdp_strategy = self.config.model.strategy
+        if fsdp_strategy == "fsdp":
+            self.fsdp_model = FSDP(
+                self.model,
+                cpu_offload=cpu_offload,
+                param_init_fn=init_fn,
+                use_orig_params=False,
+                auto_wrap_policy=auto_wrap_policy,
+                device_id=get_torch_device().current_device(),
+                sharding_strategy=ShardingStrategy.FULL_SHARD,
+                mixed_precision=mixed_precision,
+                sync_module_states=True,
+                device_mesh=self.device_mesh,
+                forward_prefetch=False,
+            )
+        elif fsdp_strategy == "fsdp2":
+            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
+            mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32,
+                                             cast_forward_inputs=True)
+
+            fsdp_kwargs = {
+                "mesh": self.device_mesh,
+                "mp_policy": mp_policy,
+                "offload_policy": cpu_offload,
+                "reshard_after_forward": True,
+            }
+            full_state = self.model.state_dict()
+            apply_fsdp2(self.model, fsdp_kwargs, self.config.model.fsdp_config)
+            fsdp2_load_full_state_dict(self.model, full_state, self.device_mesh, cpu_offload)
+            self.fsdp_model = self.model
+        else:
+            raise NotImplementedError(f"not implement {fsdp_strategy}")
 
         log_gpu_memory_usage("After FSDP wrapping", logger=logger)
 
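In the "fsdp2" branch, apply_fsdp2 wraps the model with the composable fully_shard API, which shards parameters in place into DTensors; since fully_shard has no sync_module_states equivalent, the full state dict is captured before sharding and redistributed afterwards via fsdp2_load_full_state_dict. A minimal sketch of what the wrapping step amounts to; the name apply_fsdp2_sketch and the wrap_classes tuple are illustrative only, the real logic lives in verl.utils.fsdp_utils.apply_fsdp2:

    from torch.distributed._composable.fsdp import fully_shard

    def apply_fsdp2_sketch(model, fsdp_kwargs, wrap_classes=("LlamaDecoderLayer",)):
        # Shard each transformer block first, then the root module, mirroring
        # FSDP1's transformer auto-wrap policy. wrap_classes is a stand-in, not
        # verl's actual fsdp_config key.
        for module in model.modules():
            if module.__class__.__name__ in wrap_classes:
                fully_shard(module, **fsdp_kwargs)
        fully_shard(model, **fsdp_kwargs)  # parameters become sharded DTensors in place

fsdp_kwargs above maps directly onto fully_shard's keyword arguments (mesh, mp_policy, offload_policy, reshard_after_forward), which is why it is built as a plain dict rather than an FSDP1-style constructor call.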
@@ -373,7 +403,12 @@ def training_step(self, batch: TensorDict):
             loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches
             step_loss += loss.item()
 
-        grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
+        if self.config.model.strategy == 'fsdp':
+            grad_norm = self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)
+        elif self.config.model.strategy == 'fsdp2':
+            grad_norm = fsdp2_clip_grad_norm_(self.fsdp_model.parameters(), max_norm=self.config.optim.clip_grad)
+        else:
+            raise NotImplementedError(f"not implement {self.config.model.strategy}")
 
         log_gpu_memory_usage("Before optimizer step", logger=logger)
 
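FSDP1 exposes clip_grad_norm_ as a method on the wrapped module; with fully_shard the model stays a plain nn.Module whose gradients are DTensors, so the patch routes clipping through verl's fsdp2_clip_grad_norm_ instead. A rough sketch of what such a helper has to do, assuming the PyTorch >= 2.6 utilities get_total_norm and clip_grads_with_norm_; this is not verl's implementation:

    import torch
    from torch.distributed.tensor import DTensor

    def clip_grad_norm_sketch(parameters, max_norm: float) -> torch.Tensor:
        # Hypothetical sketch only; the trainer actually calls
        # verl.utils.fsdp_utils.fsdp2_clip_grad_norm_.
        parameters = list(parameters)
        grads = [p.grad for p in parameters if p.grad is not None]
        total_norm = torch.nn.utils.get_total_norm(grads, norm_type=2.0)
        if isinstance(total_norm, DTensor):
            # the norm over DTensor grads is itself sharded; materialize the global value
            total_norm = total_norm.full_tensor()
        torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm)
        return total_norm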
@@ -414,21 +449,44 @@ def validation_step(self, batch: TensorDict):
 
     def save_checkpoint(self, step):
         # save checkpoint
-        from torch.distributed.fsdp import FullStateDictConfig, StateDictType
+        path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")
 
-        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
-            state_dict = self.fsdp_model.state_dict()
+        fsdp_strategy = self.config.model.strategy
+        if fsdp_strategy == "fsdp":
+            # FSDP1 checkpoint saving
+            from torch.distributed.fsdp import FullStateDictConfig, StateDictType
+
+            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+            with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
+                state_dict = self.fsdp_model.state_dict()
+
+            # save huggingface model
+            if self.device_mesh.get_rank() == 0:
+                os.makedirs(path, exist_ok=True)
+                self.model.save_pretrained(path, state_dict=state_dict)
+                self.tokenizer.save_pretrained(path)
+        elif fsdp_strategy == "fsdp2":
+            # FSDP2 checkpoint saving
+            from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
+
+            # Get full state dict with FSDP2
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+            state_dict = get_model_state_dict(self.fsdp_model, options=options)
+
+            # save huggingface model
+            if self.device_mesh.get_rank() == 0:
+                os.makedirs(path, exist_ok=True)
+                self.model.save_pretrained(path, state_dict=state_dict)
+                self.model_config.save_pretrained(path)
+                self.tokenizer.save_pretrained(path)
+        else:
+            raise NotImplementedError(f"not implement {fsdp_strategy}")
+
+        # Copy to HDFS if configured
+        if self.device_mesh.get_rank() == 0 and self.config.trainer.default_hdfs_dir:
+            hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
+            hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
 
-        path = os.path.join(self.config.trainer.default_local_dir, f"global_step_{step}")
-        # save huggingface model
-        if self.device_mesh.get_rank() == 0:
-            os.makedirs(path, exist_ok=True)
-            self.model.save_pretrained(path, state_dict=state_dict)
-            self.tokenizer.save_pretrained(path)
-            if self.config.trainer.default_hdfs_dir:
-                hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True)
-                hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True)
         torch.distributed.barrier()
 
     def fit(self):
 
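The FSDP2 branch gathers a full, CPU-offloaded state dict through torch.distributed.checkpoint's get_model_state_dict rather than FSDP1's state_dict_type context manager, and additionally writes config.json via self.model_config (stored in _build_model_optimizer above), so the output directory is a self-contained HuggingFace checkpoint. A quick reload check using the standard transformers API, with a hypothetical path standing in for trainer.default_local_dir/global_step_{step}:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    ckpt = "checkpoints/global_step_100"  # hypothetical path; the trainer writes default_local_dir/global_step_{step}
    model = AutoModelForCausalLM.from_pretrained(ckpt)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)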
@@ -462,6 +520,7 @@ def fit(self):
                 self.train_dataloader,
                 total=self.steps_per_epoch,
                 desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}",
+                disable=rank != 0
             ):
                 global_step += 1
                 data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name)