From 2388fe7a3031821a7c0835b5eaec335402b164a3 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 27 Nov 2025 15:14:14 +0100 Subject: [PATCH 1/9] Fix fp8 + some enhancement --- src/transformers/core_model_loading.py | 1 + .../integrations/finegrained_fp8.py | 122 ++++++------------ src/transformers/integrations/mxfp4.py | 12 +- src/transformers/quantizers/base.py | 62 +++------ .../quantizers/quantizer_finegrained_fp8.py | 3 +- .../quantizers/quantizers_utils.py | 8 ++ 6 files changed, 73 insertions(+), 135 deletions(-) diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 30ce9d6b6732..01dc33de3d66 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -732,6 +732,7 @@ def convert_and_load_state_dict_in_model( if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key: # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it, # so we need to make sure to load the tensor with the same dtype from the checkpoint + # TODO: make the condition more srict more native fp8 model such as qwen2moe fp8 _dtype = None elif dtype_plan != {} and dtype_policy_alt.search(renamed_key): matched_dtype_pattern = dtype_policy_alt.search(renamed_key) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 42854324bc0d..d2e20bcbcd5b 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -18,6 +18,7 @@ from typing import Any from ..core_model_loading import ConversionOps +from ..quantizers.quantizers_utils import should_convert_module from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging @@ -307,29 +308,25 @@ def w8a8_block_fp8_matmul_compile( class FP8Linear(nn.Linear): - dtype = torch.float8_e4m3fn - def __init__( self, in_features: int, out_features: int, bias: bool = False, - dtype=None, + dtype=torch.float8_e4m3fn, block_size: tuple[int, int] | None = None, - device=None, activation_scheme="dynamic", ): super().__init__(in_features, out_features) self.in_features = in_features self.out_features = out_features - - self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device)) + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) if self.weight.element_size() == 1: scale_out_features = (out_features + block_size[0] - 1) // block_size[0] scale_in_features = (in_features + block_size[1] - 1) // block_size[1] self.weight_scale_inv = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device) + torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) ) else: self.register_parameter("weight_scale_inv", None) @@ -380,9 +377,7 @@ def _ceil_div(a, b): class FP8Expert(nn.Module): - dtype = torch.float8_e4m3fn - - def __init__(self, config, block_size, device): + def __init__(self, config, block_size, dtype=torch.float8_e4m3fn): super().__init__() from ..activations import ACT2FN @@ -396,10 +391,10 @@ def __init__(self, config, block_size, device): Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim self.gate_up_proj = nn.Parameter( - torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device) + torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype) ) self.down_proj = nn.Parameter( - torch.zeros(self.num_experts, 
Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device) + torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype) ) # Create inverse scale tiles only when using 1-byte types (fp8) @@ -410,14 +405,14 @@ def __init__(self, config, block_size, device): gu_scale_o = _ceil_div(Wg_out, bo) gu_scale_i = _ceil_div(Wg_in, bi) self.gate_up_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device) + torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32) ) # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) dp_scale_o = _ceil_div(Wd_out, bo) dp_scale_i = _ceil_div(Wd_in, bi) self.down_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device) + torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32) ) else: # Match FP8Linear behavior when not using 1-byte weights @@ -487,81 +482,46 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to torch_accelerator_module.synchronize() return output.to(dtype=input.dtype) - -# TODO: we do need this.... but not recursive... -def _replace_with_fp8_linear( +def replace_with_fp8_linear( model, - tp_plan=None, modules_to_not_convert=None, - current_key_name=None, quantization_config=None, - has_been_replaced=False, + pre_quantized=False, ): + """Helper function to replace model layers with FP8 versions.""" + iterator = list(model.named_parameters()).copy() - for name, empty_tensor in iterator: - current_key_name = name - name = name.rsplit(".", 1)[0] if "." in name else name - module = model.get_submodule(name) - - current_key_name_str = re.sub(r"\d+", "*", current_key_name) - if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): - with init_empty_weights(): - if ( - "gate_up_proj" in current_key_name - or "down_proj" in current_key_name - and "experts" in current_key_name - ): # Experts! - in_features = empty_tensor.size(-2) - out_features = empty_tensor.size(-1) - model.set_submodule( - name, - FP8Expert( - config=model.config, - block_size=quantization_config.weight_block_size, - device=empty_tensor.device, - ), + for full_name, empty_tensor in iterator: + if not should_convert_module(full_name, modules_to_not_convert): + continue + module_name = full_name.rsplit(".", 1)[0] if "." 
in full_name else full_name + module = model.get_submodule(module_name) + # we need this to correctly materialize the weights during quantization + module_kwargs = {} if pre_quantized else {"dtype": None} + new_module = None + with init_empty_weights(): + if ( + "gate_up_proj" in module_name + or "down_proj" in module_name + and "experts" in module_name + ): + new_module = FP8Expert( + config=model.config, + block_size=quantization_config.weight_block_size, + **module_kwargs ) - - elif isinstance(module, nn.Linear): - in_features = module.in_features - out_features = module.out_features - model.set_submodule( - name, - FP8Linear( - in_features=in_features, - out_features=out_features, - bias=module.bias is not None, - device=module.weight.device, - dtype=module.weight.dtype, - activation_scheme=quantization_config.activation_scheme, - block_size=quantization_config.weight_block_size, - ), + elif isinstance(module, nn.Linear): + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + **module_kwargs ) + if new_module is not None: + model.set_submodule(module_name, new_module) has_been_replaced = True - # when changing a layer the TP PLAN for that layer should be updated. TODO - - return model, has_been_replaced - - -def replace_with_fp8_linear( - model, - modules_to_not_convert=None, - quantization_config=None, -): - """Helper function to replace model layers with FP8 versions.""" - if modules_to_not_convert is None: - modules_to_not_convert = [] - modules_to_not_convert += ["lm_head"] - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_fp8_linear( - model, - tp_plan=model._tp_plan, - modules_to_not_convert=modules_to_not_convert, - quantization_config=quantization_config, - ) if not has_been_replaced: logger.warning( diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 6430c0d9d57d..4c3c8d6194f8 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -29,7 +29,7 @@ import re from contextlib import contextmanager -from ..quantizers.quantizers_utils import get_module_from_name +from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module logger = logging.get_logger(__name__) @@ -435,16 +435,6 @@ def mlp_forward(self, hidden_states): routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) return routed_out, router_logits - -def should_convert_module(current_key_name, patterns): - current_key_name_str = ".".join(current_key_name) - if not any( - re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns - ): - return True - return False - - def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): from ..integrations.tensor_parallel import shard_and_distribute_module diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index b0a5873da303..e601934207fb 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -47,49 +47,27 @@ def _assign_original_dtype(module, original_dtype): def get_keys_to_not_convert(model): r""" - An utility function to get the 
key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want - to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in - int8. - - Parameters: - model (`torch.nn.Module`): - Input model + Function to automatically detect keys to not convert for usage like quantization. For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. """ - # Create a copy of the model and tie the weights, then - # check if it contains tied weights - tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` - tied_model.tie_weights() - - tied_params = find_tied_parameters(tied_model) - tied_keys = sum(tied_params, []) - has_tied_params = len(tied_keys) > 0 - - # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision - if not has_tied_params: - output_emb = model.get_output_embeddings() - if output_emb is not None: - list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] - return list_last_module - - # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision - list_modules = list(model.named_parameters()) - list_last_module = [list_modules[-1][0]] - # add last module together with tied weights - intersection = set(list_last_module) - set(tied_keys) - list_untouched = list(set(tied_keys)) + list(intersection) - - # remove ".weight" from the keys - names_to_remove = [".weight", ".bias"] - filtered_module_names = [] - for name in list_untouched: - for name_to_remove in names_to_remove: - if name_to_remove in name: - name = name.replace(name_to_remove, "") - filtered_module_names.append(name) - - return filtered_module_names + # remove tied weights + tied_keys = set() + if len(model.all_tied_weights_keys) > 0: + tied_keys = set(model.all_tied_weights_keys.values()) | set(model.all_tied_weights_keys.keys()) + + # remove last module + last_module_key = {list(model.named_parameters())[-1][0]} + + # remove output emb + output_emb_module = model.get_output_embeddings() + output_emb_keys = { + name for name, module in model.named_modules() + if output_emb_module is not None and id(module) == id(output_emb_module) + } + candidates = tied_keys | last_module_key | output_emb_keys + modules_to_not_convert = {name.replace(suffix, "") for name in candidates for suffix in [".weight", ".bias"]} + return modules_to_not_convert class HfQuantizer(ABC): """ @@ -360,6 +338,8 @@ def get_modules_to_not_convert( if keep_in_fp32_modules is not None: modules_to_not_convert.extend(keep_in_fp32_modules) + modules_to_not_convert = list(set(modules_to_not_convert)) + return modules_to_not_convert @property diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 293dede7dc52..190263b292d1 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -158,16 +158,15 @@ def _process_model_before_weight_loading( ): from ..integrations.finegrained_fp8 import replace_with_fp8_linear - # takes 2 fucking seconds self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) - # while this one 
is 81ms :) model = replace_with_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, ) model.config.quantization_config = self.quantization_config diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py index 94955495f561..ad6dfaf75ed1 100644 --- a/src/transformers/quantizers/quantizers_utils.py +++ b/src/transformers/quantizers/quantizers_utils.py @@ -19,3 +19,11 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: module_name, tensor_name = tensor_name.rsplit(".", 1) module = module.get_submodule(module_name) return module, tensor_name + +def should_convert_module(current_key_name, patterns): + current_key_name_str = ".".join(current_key_name) + if not any( + re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns + ): + return True + return False From 260f7c529673978408e9b6611f122d938c37a8c6 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 27 Nov 2025 14:20:02 +0000 Subject: [PATCH 2/9] style --- .../integrations/finegrained_fp8.py | 36 +++++++------------ src/transformers/integrations/mxfp4.py | 2 +- src/transformers/quantizers/base.py | 9 ++--- .../quantizers/quantizers_utils.py | 2 ++ 4 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index d2e20bcbcd5b..37ecae6809fb 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from collections.abc import Sequence from typing import Any @@ -390,12 +389,8 @@ def __init__(self, config, block_size, dtype=torch.float8_e4m3fn): Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim - self.gate_up_proj = nn.Parameter( - torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype) - ) - self.down_proj = nn.Parameter( - torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype) - ) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype)) + self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype)) # Create inverse scale tiles only when using 1-byte types (fp8) if self.gate_up_proj.element_size() == 1: @@ -482,6 +477,7 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to torch_accelerator_module.synchronize() return output.to(dtype=input.dtype) + def replace_with_fp8_linear( model, modules_to_not_convert=None, @@ -500,25 +496,19 @@ def replace_with_fp8_linear( module_kwargs = {} if pre_quantized else {"dtype": None} new_module = None with init_empty_weights(): - if ( - "gate_up_proj" in module_name - or "down_proj" in module_name - and "experts" in module_name - ): + if "gate_up_proj" in module_name or "down_proj" in module_name and "experts" in module_name: new_module = FP8Expert( - config=model.config, - block_size=quantization_config.weight_block_size, - **module_kwargs - ) + config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs + ) elif isinstance(module, nn.Linear): new_module = FP8Linear( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - 
activation_scheme=quantization_config.activation_scheme, - block_size=quantization_config.weight_block_size, - **module_kwargs - ) + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + **module_kwargs, + ) if new_module is not None: model.set_submodule(module_name, new_module) has_been_replaced = True diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 4c3c8d6194f8..14943ff4a4d2 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -26,7 +26,6 @@ if is_accelerate_available(): from accelerate import init_empty_weights -import re from contextlib import contextmanager from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module @@ -435,6 +434,7 @@ def mlp_forward(self, hidden_states): routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim) return routed_out, router_logits + def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): from ..integrations.tensor_parallel import shard_and_distribute_module diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index e601934207fb..1ae09ca9993c 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from copy import deepcopy from typing import TYPE_CHECKING, Any from ..utils import is_accelerate_available, is_torch_available, logging @@ -21,7 +20,7 @@ if is_accelerate_available(): - from accelerate.utils import find_tied_parameters + pass if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel @@ -57,11 +56,12 @@ def get_keys_to_not_convert(model): # remove last module last_module_key = {list(model.named_parameters())[-1][0]} - + # remove output emb output_emb_module = model.get_output_embeddings() output_emb_keys = { - name for name, module in model.named_modules() + name + for name, module in model.named_modules() if output_emb_module is not None and id(module) == id(output_emb_module) } candidates = tied_keys | last_module_key | output_emb_keys @@ -69,6 +69,7 @@ def get_keys_to_not_convert(model): modules_to_not_convert = {name.replace(suffix, "") for name in candidates for suffix in [".weight", ".bias"]} return modules_to_not_convert + class HfQuantizer(ABC): """ Abstract class of the HuggingFace quantizer. Supports for now quantizing HF transformers models for inference and/or quantization. diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py index ad6dfaf75ed1..065b62448ec0 100644 --- a/src/transformers/quantizers/quantizers_utils.py +++ b/src/transformers/quantizers/quantizers_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import re from typing import Any @@ -20,6 +21,7 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: module = module.get_submodule(module_name) return module, tensor_name + def should_convert_module(current_key_name, patterns): current_key_name_str = ".".join(current_key_name) if not any( From 88f9788660e932f0e46a3f339b1c55c35d925aab Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 27 Nov 2025 14:22:07 +0000 Subject: [PATCH 3/9] Add coauthor Co-authored-by: Yang Kai From e2c8e5bfc7d2368346999f8e8504783d63cf67f2 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Thu, 27 Nov 2025 16:15:06 +0000 Subject: [PATCH 4/9] fix --- .../integrations/finegrained_fp8.py | 60 ++++++----------- src/transformers/integrations/mxfp4.py | 66 ++++--------------- .../quantizers/quantizer_mxfp4.py | 8 +-- .../quantizers/quantizers_utils.py | 24 +++++-- .../quantization/finegrained_fp8/test_fp8.py | 17 ++--- 5 files changed, 57 insertions(+), 118 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 37ecae6809fb..dffaf18c4f08 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -317,23 +317,13 @@ def __init__( activation_scheme="dynamic", ): super().__init__(in_features, out_features) - self.in_features = in_features - self.out_features = out_features - self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) - - if self.weight.element_size() == 1: - scale_out_features = (out_features + block_size[0] - 1) // block_size[0] - scale_in_features = (in_features + block_size[1] - 1) // block_size[1] - self.weight_scale_inv = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32) - ) - else: - self.register_parameter("weight_scale_inv", None) - self.block_size = block_size - self.activation_scheme = activation_scheme + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + scale_out_features = (out_features + block_size[0] - 1) // block_size[0] + scale_in_features = (in_features + block_size[1] - 1) // block_size[1] + self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) if bias: self.bias = nn.Parameter(torch.empty(self.out_features)) else: @@ -392,27 +382,21 @@ def __init__(self, config, block_size, dtype=torch.float8_e4m3fn): self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype)) self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype)) - # Create inverse scale tiles only when using 1-byte types (fp8) - if self.gate_up_proj.element_size() == 1: - bo, bi = self.block_size + bo, bi = self.block_size - # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) - gu_scale_o = _ceil_div(Wg_out, bo) - gu_scale_i = _ceil_div(Wg_in, bi) - self.gate_up_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32) - ) + # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) + gu_scale_o = _ceil_div(Wg_out, bo) + gu_scale_i = _ceil_div(Wg_in, bi) + self.gate_up_proj_scale_inv = nn.Parameter( + torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32) + ) - # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) - dp_scale_o = _ceil_div(Wd_out, bo) - dp_scale_i = _ceil_div(Wd_in, bi) - self.down_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32) - ) - else: 
- # Match FP8Linear behavior when not using 1-byte weights - self.register_parameter("gate_up_proj_scale_inv", None) - self.register_parameter("down_proj_scale_inv", None) + # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) + dp_scale_o = _ceil_div(Wd_out, bo) + dp_scale_i = _ceil_div(Wd_in, bi) + self.down_proj_scale_inv = nn.Parameter( + torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32) + ) # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default self.register_parameter("gate_up_bias", None) @@ -485,13 +469,10 @@ def replace_with_fp8_linear( pre_quantized=False, ): """Helper function to replace model layers with FP8 versions.""" - - iterator = list(model.named_parameters()).copy() - for full_name, empty_tensor in iterator: - if not should_convert_module(full_name, modules_to_not_convert): + has_been_replaced = False + for module_name, module in model.named_modules(): + if not should_convert_module(module_name, modules_to_not_convert): continue - module_name = full_name.rsplit(".", 1)[0] if "." in full_name else full_name - module = model.get_submodule(module_name) # we need this to correctly materialize the weights during quantization module_kwargs = {} if pre_quantized else {"dtype": None} new_module = None @@ -518,7 +499,6 @@ def replace_with_fp8_linear( "You are loading your model using fp8 but no linear modules were found in your model." " Please double check your model architecture." ) - return model diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 14943ff4a4d2..4f4fee40bf96 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -594,70 +594,28 @@ def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton ) -def _replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, - config=None, -): - if current_key_name is None: - current_key_name = [] +def replace_with_mxfp4_linear(model, modules_to_not_convert=None, quantization_config=None): + if quantization_config.dequantize: + return model - for name, module in model.named_children(): - current_key_name.append(name) - if not should_convert_module(current_key_name, modules_to_not_convert): - current_key_name.pop(-1) + from kernels import get_kernel + + global triton_kernels_hub + triton_kernels_hub = get_kernel("kernels-community/triton_kernels") + + has_been_replaced = False + for module_name, module in model.named_modules(): + if not should_convert_module(module_name, modules_to_not_convert): continue if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: with init_empty_weights(): - model._modules[name] = Mxfp4GptOssExperts(config) + model.set_submodule(module_name, Mxfp4GptOssExperts(model.config)) has_been_replaced = True if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize: from types import MethodType module.forward = MethodType(mlp_forward, module) - if len(list(module.children())) > 0: - _, has_been_replaced = _replace_with_mxfp4_linear( - module, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - config=config, - ) - current_key_name.pop(-1) - return model, has_been_replaced - -def replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - config=None, -): - if 
quantization_config.dequantize: - return model - else: - from kernels import get_kernel - - global triton_kernels_hub - triton_kernels_hub = get_kernel("kernels-community/triton_kernels") - - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - ) if not has_been_replaced: logger.warning( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index ae91138653d2..534c24e660ef 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -301,6 +301,7 @@ def _process_model_before_weight_loading( self, model: "PreTrainedModel", keep_in_fp32_modules: list[str] | None = None, + use_kernels: bool = False, **kwargs, ): from ..integrations import replace_with_mxfp4_linear @@ -309,7 +310,6 @@ def _process_model_before_weight_loading( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) - use_kernels = kwargs.get("use_kernels", False) # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling if use_kernels: logger.warning_once( @@ -318,12 +318,8 @@ def _process_model_before_weight_loading( ) self.quantization_config.dequantize = True - config = model.config model = replace_with_mxfp4_linear( - model, - modules_to_not_convert=self.modules_to_not_convert, - quantization_config=self.quantization_config, - config=config, + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py index 065b62448ec0..5e3429b602f8 100644 --- a/src/transformers/quantizers/quantizers_utils.py +++ b/src/transformers/quantizers/quantizers_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Any +from typing import Any, Optional def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: @@ -22,10 +22,20 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: return module, tensor_name -def should_convert_module(current_key_name, patterns): - current_key_name_str = ".".join(current_key_name) - if not any( - re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns - ): +def should_convert_module(full_name, patterns: Optional[list] = None): + if patterns is None: return True - return False + + # We should avoid converting in the following situations: + # 1. The pattern appears as a prefix followed by a dot in `full_name` + # (e.g., "model.decoder.layer.11." matches "model.decoder.layer.11.attn.weight"). + # 2. The pattern matches `full_name` exactly or via regex + # (e.g., "lm_head" matches "lm_head"; "model.decoder.layer.*" matches "model.decoder.layer.11.attn.weight"). + # 3. 
`full_name` ends with the pattern + # (e.g., "fc1" matches "model.decoder.layers.23.fc1"). + + should_not_convert = any( + re.match(f"{key}\\.", full_name) or re.match(f"{key}", full_name) or full_name.endswith(key) + for key in patterns + ) + return not should_not_convert diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index b05ee61dcb38..414cc3f8f437 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -132,25 +132,21 @@ def test_quantized_model_conversion(self): for module in model.modules(): if isinstance(module, torch.nn.Linear): nb_linears += 1 - model = replace_with_fp8_linear(model, quantization_config=quantization_config) nb_fp8_linear = 0 for module in model.modules(): if isinstance(module, FP8Linear): nb_fp8_linear += 1 - print(model) - self.assertEqual(nb_linears - 1, nb_fp8_linear) - + self.assertEqual(nb_linears, nb_fp8_linear) with init_empty_weights(): model = OPTForCausalLM(config) - quantization_config = FineGrainedFP8Config(modules_to_not_convert=["fc1"]) - model = replace_with_fp8_linear(model, quantization_config=quantization_config) + quantization_config = FineGrainedFP8Config() + model = replace_with_fp8_linear(model, modules_to_not_convert=["fc1"], quantization_config=quantization_config) nb_fp8_linear = 0 for module in model.modules(): if isinstance(module, FP8Linear): nb_fp8_linear += 1 - - self.assertEqual(nb_linears - 25, nb_fp8_linear) + self.assertEqual(nb_linears - 24, nb_fp8_linear) def test_quantized_model(self): """ @@ -209,7 +205,6 @@ def test_quantized_model_multi_accelerator(self): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, device_map="auto", quantization_config=quantization_config ) - print("hf_device_map", quantized_model.hf_device_map) self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) @@ -271,7 +266,7 @@ def test_linear_preserves_shape(self): """ from transformers.integrations import FP8Linear - linear = FP8Linear(256, 256, block_size=(128, 128), device=self.device) + linear = FP8Linear(256, 256, block_size=(128, 128)).to(self.device) x = torch.rand((1, 5, 256)).to(self.device) x_ = linear(x) @@ -283,7 +278,7 @@ def test_linear_with_diff_feature_size_preserves_shape(self): """ from transformers.integrations import FP8Linear - linear = FP8Linear(128, 256, block_size=(128, 128), device=self.device) + linear = FP8Linear(128, 256, block_size=(128, 128)).to(self.device) x = torch.rand((1, 5, 128)).to(self.device) x_ = linear(x) From 1ee8e5e19246c61aca0c9f0163d03c732c16a6fd Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 28 Nov 2025 15:27:28 +0000 Subject: [PATCH 5/9] style --- src/transformers/quantizers/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 1ae09ca9993c..8f08a1b33f69 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -18,10 +18,6 @@ from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod from .quantizers_utils import get_module_from_name - -if is_accelerate_available(): - pass - if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel From 13620e0378003567aa8ffd505c3164dd76c4f547 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 28 Nov 2025 15:29:38 +0000 Subject: [PATCH 6/9] fix tests --- 
tests/quantization/finegrained_fp8/test_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index 414cc3f8f437..81e619cfa93e 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -73,7 +73,7 @@ class FP8QuantizerTest(unittest.TestCase): model_name = "meta-llama/Llama-3.2-1B" input_text = "Once upon a time" max_new_tokens = 10 - EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich." + EXPECTED_OUTPUT = "Once upon a time, there was a little girl who loved to play" device_map = torch_device offload_device_map = { "model.embed_tokens": 0, From 3ac9276ae5322bb6370c339db9b4b0dad6397a78 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 28 Nov 2025 15:32:45 +0000 Subject: [PATCH 7/9] style --- src/transformers/quantizers/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 8f08a1b33f69..d8f5609a36f4 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -14,10 +14,11 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from ..utils import is_accelerate_available, is_torch_available, logging +from ..utils import is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod from .quantizers_utils import get_module_from_name + if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel From f30a51e166c38dedcc3d91ed2d2cb682d20825f1 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 28 Nov 2025 16:55:31 +0000 Subject: [PATCH 8/9] assertin --- tests/quantization/finegrained_fp8/test_fp8.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index 81e619cfa93e..8a2521daa8c3 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -73,7 +73,8 @@ class FP8QuantizerTest(unittest.TestCase): model_name = "meta-llama/Llama-3.2-1B" input_text = "Once upon a time" max_new_tokens = 10 - EXPECTED_OUTPUT = "Once upon a time, there was a little girl who loved to play" + EXPECTED_OUTPUTS = {"Once upon a time, there was a little girl who loved to play", "Once upon a time, there was a man who was very rich."} + device_map = torch_device offload_device_map = { "model.embed_tokens": 0, @@ -156,7 +157,7 @@ def test_quantized_model(self): output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True) - self.assertEqual(output_tokens, self.EXPECTED_OUTPUT) + self.assertIn(output_tokens, self.EXPECTED_OUTPUTS) def test_save_pretrained(self): """ @@ -170,7 +171,7 @@ def test_save_pretrained(self): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map) output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) def test_weight_and_weight_scale_inv(self): """ @@ -208,7 +209,7 @@ def test_quantized_model_multi_accelerator(self): 
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_multi_accelerator def test_save_pretrained_multi_accelerators(self): @@ -224,7 +225,7 @@ def test_save_pretrained_multi_accelerators(self): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map) output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) def test_quantized_model_offload(self): """ @@ -248,7 +249,7 @@ def test_save_pretrained_offload(self): quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_accelerator From aea1af6915ef1fffe6c774926db4eed6a5490f76 Mon Sep 17 00:00:00 2001 From: Marc Sun Date: Fri, 28 Nov 2025 16:56:40 +0000 Subject: [PATCH 9/9] style --- tests/quantization/finegrained_fp8/test_fp8.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index 8a2521daa8c3..2d9137486bcc 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -73,8 +73,11 @@ class FP8QuantizerTest(unittest.TestCase): model_name = "meta-llama/Llama-3.2-1B" input_text = "Once upon a time" max_new_tokens = 10 - EXPECTED_OUTPUTS = {"Once upon a time, there was a little girl who loved to play", "Once upon a time, there was a man who was very rich."} - + EXPECTED_OUTPUTS = { + "Once upon a time, there was a little girl who loved to play", + "Once upon a time, there was a man who was very rich.", + } + device_map = torch_device offload_device_map = { "model.embed_tokens": 0,
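
Note for reviewers: the short sketch below is not part of the patch series. It is a minimal usage example of the `should_convert_module` helper that this series adds to src/transformers/quantizers/quantizers_utils.py, exercising the three match modes documented in the helper's own comment (pattern-plus-dot prefix, exact/regex match from the start of the name, and suffix match). The decoder-style module names are hypothetical, and the import assumes a transformers build that already includes these commits.

    # Hypothetical usage sketch; requires a build containing this patch series.
    from transformers.quantizers.quantizers_utils import should_convert_module

    # Exact match on the pattern keeps the module unquantized.
    assert not should_convert_module("lm_head", ["lm_head"])
    # Regex-style pattern matched from the start of the qualified name.
    assert not should_convert_module("model.decoder.layers.11.self_attn.q_proj", ["model.decoder.layers.*"])
    # Suffix match: the qualified name ends with the pattern.
    assert not should_convert_module("model.decoder.layers.23.fc1", ["fc1"])
    # No pattern matches, so the module is converted to FP8.
    assert should_convert_module("model.decoder.layers.23.fc2", ["fc1"])
    # With no patterns at all, everything is converted.
    assert should_convert_module("model.decoder.layers.23.fc2")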