diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index a8fca0577..5f97bcc76 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -496,7 +496,9 @@ def _test_mini_optimizer(dp_group): torch.testing.assert_close(w1, w3, atol=0, rtol=0) -def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing): +def _test_cast_master_weights_to_fp8( + quantization, dp_group, manual_post_all_gather_processing, keep_fp8_weight_transpose_cache +): rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) @@ -506,7 +508,12 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] mock_group = mock_groups[rank] - linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": True, + "keep_fp8_weight_transpose_cache": keep_fp8_weight_transpose_cache, + } # Create model with FP8 weights with te.quantized_model_init( @@ -583,7 +590,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gat def _test_fsdp_cast_master_weights_to_fp8( - quantization, dp_group, manual_post_all_gather_processing + quantization, dp_group, manual_post_all_gather_processing, keep_fp8_weight_transpose_cache ): rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) @@ -602,6 +609,7 @@ def _test_fsdp_cast_master_weights_to_fp8( "params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True, + "keep_fp8_weight_transpose_cache": keep_fp8_weight_transpose_cache, } # Create model with FP8 weights @@ -702,13 +710,19 @@ def run_parallel_tests() -> None: quantizations.append("fp8_block") manual_post_all_gather_processings = [False, True] + keep_fp8_weight_transpose_caches = [True, False] _test_mini_optimizer(dp_group) for quantization in quantizations: for post_ag_processing in manual_post_all_gather_processings: - _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) - _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + for keep_fp8_weight_transpose_cache in keep_fp8_weight_transpose_caches: + _test_cast_master_weights_to_fp8( + quantization, dp_group, post_ag_processing, keep_fp8_weight_transpose_cache + ) + _test_fsdp_cast_master_weights_to_fp8( + quantization, dp_group, post_ag_processing, keep_fp8_weight_transpose_cache + ) dist.destroy_process_group() diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 5da5fcdc1..e535260df 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -482,7 +482,7 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten if isinstance(model_weight, Float8Tensor): # Delayed scaling and per-tensor current scaling: if backend does not support # non-transposed FP8 GEMM, pre-create the transpose. - if not is_non_tn_fp8_gemm_supported(): + if model_weight._quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported(): model_weight._create_transpose() elif isinstance(model_weight, Float8BlockwiseQTensor): # Blockwise scaling: create column-wise storage.