diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 3d7a5efaca..247f186af1 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -14,6 +14,7 @@
 from transformer_engine.common.recipe import Recipe
 from .base import (
     get_multi_stream_cublas_workspace,
+    get_dummy_wgrad,
     TransformerEngineBaseModule,
     _2X_ACC_FPROP,
     _2X_ACC_DGRAD,
@@ -80,6 +81,7 @@ def forward(
         module,
         skip_fp8_weight_update,
         save_original_input,
+        fine_grained_activation_offloading,
         *weights_and_biases,
     ) -> torch.Tensor:
         # pylint: disable=missing-function-docstring
@@ -209,6 +211,30 @@
             if isinstance(weight, QuantizedTensorBase):
                 weight.update_usage(columnwise_usage=True)
 
+            for i in range(num_gemms):
+                weights[i].offloading_activation = False
+                weights_fp8[i].offloading_activation = False
+                biases[i].offloading_activation = False
+            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+            if fine_grained_activation_offloading and cpu_offloading:
+                raise ValueError(
+                    "Do not use fine_grained_activation_offloading and cpu_offloading at the same"
+                    " time."
+                )
+
+            if (
+                fine_grained_activation_offloading
+                and weights[0].requires_grad
+                and fuse_wgrad_accumulation
+            ):
+                grad_added_to_main_grad_list = []
+                for weight in weights:
+                    if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
+                        grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                        weight.grad_added_to_main_grad = True
+                ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
+
             tensors_to_save, tensor_objects = prepare_for_saving(
                 *inputmats,
                 *weights_fp8,
@@ -271,11 +297,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         biases = saved_tensors[3 * N : 4 * N]
         main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
 
-        if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
+        if (
+            ctx.cpu_offloading or ctx.fine_grained_activation_offloading
+        ) and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
-                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
-                w.main_grad = main_grads[i]
-                weights[i] = w
+                if not ctx.cpu_offloading:
+                    w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
+                    weights[i] = w
+                weights[i].main_grad = main_grads[i]
+                weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
 
         # Preprocess grad output
         grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -426,18 +456,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
                 ):
                     weight.grad_added_to_main_grad = True
                     if getattr(weight, "zero_out_wgrad", False):
-                        wgrad = torch.zeros(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
+                            zero=True,
                         )
                     else:
-                        wgrad = torch.empty(
-                            weight.main_grad.shape,
-                            dtype=weight.dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False,
+                        wgrad = get_dummy_wgrad(
+                            list(weight.main_grad.shape),
+                            weight.dtype,
                         )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
@@ -484,6 +511,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
             None,
             None,
             None,
+            None,
             *wgrad_list,
             *grad_biases,
         )
@@ -565,6 +593,7 @@ def __init__(
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
        delay_wgrad_compute: bool = False,
         save_original_input: bool = False,
     ) -> None:
@@ -588,6 +617,8 @@ def __init__(
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
 
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
+
         self.wgrad_store = WeightGradStore(delay_wgrad_compute)
 
         self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
@@ -806,6 +837,7 @@ def forward(
             self,
             skip_fp8_weight_update,
             self.save_original_input,
+            self.fine_grained_activation_offloading,
             *weight_tensors,
             *bias_tensors,
         )
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index cd02f31132..abb96f11a8 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -122,6 +122,7 @@ def forward(
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
         skip_fp8_weight_update: bool,
@@ -424,10 +425,37 @@ def forward(
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
 
+            # Do not offload weights and biases
+            weight.offloading_activation = False
+            weightmat.offloading_activation = False
+            if bias is not None:
+                bias.offloading_activation = False
+            ln_weight.offloading_activation = False
+            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+            if fine_grained_activation_offloading and cpu_offloading:
+                raise ValueError(
+                    "Do not use fine_grained_activation_offloading and cpu_offloading at the same"
+                    " time."
+                )
+
+            if (
+                fine_grained_activation_offloading
+                and weight.requires_grad
+                and fuse_wgrad_accumulation
+            ):
+                if hasattr(weight, "grad_added_to_main_grad"):
+                    ctx.has_grad_added_to_main_grad = True
+                    ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                    weight.grad_added_to_main_grad = True
+                    ctx.weight_object = weight
+                else:
+                    ctx.has_grad_added_to_main_grad = False
+
             if cpu_offloading:
-                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+                ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
 
-                if ctx.grad_added_to_main_grad:
+                if ctx.has_grad_added_to_main_grad:
                     # If you are passing torch.nn.Parameter through the Torch hooks, you will
                     # get back torch.Tensor. Torch rips off the Parameter wrapper.
                     # You need to preserve the weight object to have all the attributes user
@@ -560,9 +588,11 @@ def backward(
 
         # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
        # we need to connect them into one.
-        if ctx.cpu_offloading:
-            if ctx.grad_added_to_main_grad:
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
+            if ctx.has_grad_added_to_main_grad:
                 origin_weight = ctx.weight_object
+                if ctx.fine_grained_activation_offloading:
+                    origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 origin_weight.main_grad = main_grad
 
@@ -1021,6 +1051,7 @@ def wgrad_gemm(
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
             None,  # fsdp_group
             None,  # debug
             None,  # module
@@ -1156,6 +1187,7 @@ def __init__(
         delay_wgrad_compute: bool = False,
         symmetric_ar_type: Optional[str] = None,
         name: str = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()
 
@@ -1172,7 +1204,7 @@ def __init__(
         self.return_layernorm_output_gathered = return_layernorm_output_gathered
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
-
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
         self.name = name
 
@@ -1575,6 +1607,7 @@ def forward(
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
             self.ub_name,
+            self.fine_grained_activation_offloading,
             self.fsdp_group,
             self,
             skip_fp8_weight_update,
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 2ce6fb4c1d..a5ccefbb1d 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -109,6 +109,7 @@ def forward(
         ub_bulk_dgrad: bool,
         ub_bulk_wgrad: bool,
         ub_name: str,
+        fine_grained_activation_offloading: bool,
         fp8_output: bool,  # pylint: disable=unused-argument
         fsdp_group: Union[dist_group_type, None],
         module: torch.nn.Module,
@@ -395,10 +396,31 @@ def forward(
             )
             nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
 
+            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+
+            if fine_grained_activation_offloading and cpu_offloading:
+                raise ValueError(
+                    "Do not use fine_grained_activation_offloading and cpu_offloading at the same"
+                    " time."
+                )
+
+            if (
+                fine_grained_activation_offloading
+                and weight.requires_grad
+                and fuse_wgrad_accumulation
+            ):
+                if hasattr(weight, "grad_added_to_main_grad"):
+                    ctx.has_grad_added_to_main_grad = True
+                    ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
+                    weight.grad_added_to_main_grad = True
+                    ctx.weight_object = weight
+                else:
+                    ctx.has_grad_added_to_main_grad = False
+
             if cpu_offloading:
-                ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
+                ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
 
-                if ctx.grad_added_to_main_grad:
+                if ctx.has_grad_added_to_main_grad:
                     # If you are passing torch.nn.Parameter through the Torch hooks, you will
                     # get back torch.Tensor. Torch rips off the Parameter wrapper.
                     # You need to preserve the weight object to have all the attributes user
@@ -406,6 +428,11 @@
                     # weights if weights are externally touched outside this module
                     ctx.weight_object = weight
 
+            # Do not offload weights and biases
+            weight.offloading_activation = False
+            weightmat.offloading_activation = False
+            if bias is not None:
+                bias.offloading_activation = False
             # TODO(ksivamani): Check memory usage
             tensors_to_save, tensor_objects = prepare_for_saving(
                 saved_inputmat,
@@ -493,9 +520,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             else None
         )
 
-        if ctx.cpu_offloading:
-            if ctx.grad_added_to_main_grad:
+        if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
+            if ctx.has_grad_added_to_main_grad:
                 weight = ctx.weight_object
+                if ctx.fine_grained_activation_offloading:
+                    weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
             if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
                 weight.main_grad = main_grad
 
@@ -968,6 +997,7 @@ def wgrad_gemm(
             None,  # ub_bulk_dgrad
             None,  # ub_bulk_wgrad
             None,  # ub_name
+            None,  # fine_grained_activation_offloading
            None,  # fp8_output
             None,  # fsdp_group
             None,  # module
@@ -1090,6 +1120,7 @@ def __init__(
         symmetric_ar_type: Optional[str] = None,
         save_original_input: bool = False,
         name: Optional[str] = None,
+        fine_grained_activation_offloading: bool = False,
     ) -> None:
         super().__init__()
 
@@ -1105,7 +1136,7 @@ def __init__(
         self.symmetric_ar_type = symmetric_ar_type
         self.save_original_input = save_original_input
         self.name = name
-
+        self.fine_grained_activation_offloading = fine_grained_activation_offloading
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
 
         if device == "meta":
@@ -1452,6 +1483,7 @@ def forward(
             self.ub_bulk_dgrad,
             self.ub_bulk_wgrad,
             self.ub_name,
+            self.fine_grained_activation_offloading,
             fp8_output,
             self.fsdp_group,
             self,
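
Usage sketch (not part of the patch): the snippet below shows how the fine_grained_activation_offloading flag added to the module constructors above might be exercised on a plain Linear layer. It is a minimal sketch that assumes transformer_engine is installed with a CUDA device available, that the surrounding activation-offloading machinery (hooks/context, not shown in this diff) is configured elsewhere, and that fuse_wgrad_accumulation follows the usual Megatron-style convention of a main_grad buffer on the weight. The sizes and parameter choices are illustrative only.

import torch
import transformer_engine.pytorch as te

# fine_grained_activation_offloading comes from this patch; it is mutually
# exclusive with cpu_offloading (forward raises ValueError if both are enabled).
layer = te.Linear(
    1024,
    1024,
    bias=True,
    fuse_wgrad_accumulation=True,
    fine_grained_activation_offloading=True,
)

# fuse_wgrad_accumulation expects a Megatron-style .main_grad buffer on the
# weight; the wgrad GEMM accumulates into it instead of populating .grad.
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
out = layer(inp)
out.sum().backward()  # weight gradient is accumulated into layer.weight.main_grad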