Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@ def convert_and_load_state_dict_in_model(
if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key:
# if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it,
# so we need to make sure to load the tensor with the same dtype from the checkpoint
# TODO: make the condition stricter for native fp8 models such as qwen2moe fp8
_dtype = None
elif dtype_plan != {} and dtype_policy_alt.search(renamed_key):
matched_dtype_pattern = dtype_policy_alt.search(renamed_key)
Expand Down
164 changes: 47 additions & 117 deletions src/transformers/integrations/finegrained_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections.abc import Sequence
from typing import Any

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import should_convert_module
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging


Expand Down Expand Up @@ -307,37 +307,23 @@ def w8a8_block_fp8_matmul_compile(


class FP8Linear(nn.Linear):
dtype = torch.float8_e4m3fn

def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
dtype=None,
dtype=torch.float8_e4m3fn,
block_size: tuple[int, int] | None = None,
device=None,
activation_scheme="dynamic",
):
super().__init__(in_features, out_features)
self.in_features = in_features
self.out_features = out_features

self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))

if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale_inv = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
)
else:
self.register_parameter("weight_scale_inv", None)

self.block_size = block_size

self.activation_scheme = activation_scheme

self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
if bias:
self.bias = nn.Parameter(torch.empty(self.out_features))
else:
Expand Down Expand Up @@ -380,9 +366,7 @@ def _ceil_div(a, b):


class FP8Expert(nn.Module):
dtype = torch.float8_e4m3fn

def __init__(self, config, block_size, device):
def __init__(self, config, block_size, dtype=torch.float8_e4m3fn):
super().__init__()

from ..activations import ACT2FN
Expand All @@ -395,34 +379,24 @@ def __init__(self, config, block_size, device):
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim

self.gate_up_proj = nn.Parameter(
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
)
self.down_proj = nn.Parameter(
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
)
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype))
self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype))

# Create inverse scale tiles only when using 1-byte types (fp8)
if self.gate_up_proj.element_size() == 1:
bo, bi = self.block_size
bo, bi = self.block_size

# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scale_inv = nn.Parameter(
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
)
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scale_inv = nn.Parameter(
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32)
)

# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scale_inv = nn.Parameter(
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
)
else:
# Match FP8Linear behavior when not using 1-byte weights
self.register_parameter("gate_up_proj_scale_inv", None)
self.register_parameter("down_proj_scale_inv", None)
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scale_inv = nn.Parameter(
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32)
)

# (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default
self.register_parameter("gate_up_bias", None)
Expand Down Expand Up @@ -488,87 +462,43 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to
return output.to(dtype=input.dtype)


# TODO: we do need this.... but not recursive...
def _replace_with_fp8_linear(
    model,
    tp_plan=None,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    """Swap expert containers and ``nn.Linear`` layers of ``model`` for FP8 modules.

    A snapshot of ``model.named_parameters()`` is walked. For every parameter
    whose digit-normalised name contains none of the ``modules_to_not_convert``
    keys, the module owning that parameter is replaced in place:

    * by an ``FP8Expert`` when the parameter name contains ``"experts"`` together
      with ``"gate_up_proj"`` or ``"down_proj"``;
    * otherwise by an ``FP8Linear`` when the owning module is an ``nn.Linear``.

    Args:
        model: Module modified in place. Its ``config`` attribute is read when an
            expert parameter is found.
        tp_plan: Accepted for interface compatibility; not used by this function.
        modules_to_not_convert: Substrings that exclude a parameter. They are
            compared against the parameter name with every run of digits
            replaced by ``*``. ``None`` means nothing is excluded.
        current_key_name: Accepted for interface compatibility; the incoming
            value is ignored.
        quantization_config: Provides ``weight_block_size`` and
            ``activation_scheme`` for the new modules.
        has_been_replaced: Initial value of the flag returned to the caller.

    Returns:
        A ``(model, has_been_replaced)`` tuple; the flag is ``True`` as soon as
        at least one module has been swapped (or if it was passed in as ``True``).
    """
    excluded_keys = modules_to_not_convert or []

    # Iterate over a snapshot: modules (and therefore parameters) are replaced
    # while we walk the model.
    for param_name, empty_tensor in list(model.named_parameters()):
        if "." not in param_name:
            # A parameter registered directly on `model` has no owning submodule
            # that could be replaced; `get_submodule` would raise on its name.
            continue
        module_name = param_name.rsplit(".", 1)[0]
        module = model.get_submodule(module_name)

        # Collapse numeric indices so exclusion keys written with "*" can match.
        normalized_name = re.sub(r"\d+", "*", param_name)
        if any(key in normalized_name for key in excluded_keys):
            continue

        # Explicit grouping: `a or b and c` parses as `a or (b and c)`, which
        # would flag every "gate_up_proj" parameter (dense layers included) as
        # an expert weight.
        is_expert_param = "experts" in param_name and (
            "gate_up_proj" in param_name or "down_proj" in param_name
        )

        with init_empty_weights():
            if is_expert_param:  # Experts!
                model.set_submodule(
                    module_name,
                    FP8Expert(
                        config=model.config,
                        block_size=quantization_config.weight_block_size,
                        device=empty_tensor.device,
                    ),
                )
                has_been_replaced = True
            elif isinstance(module, nn.Linear):
                model.set_submodule(
                    module_name,
                    FP8Linear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        device=module.weight.device,
                        dtype=module.weight.dtype,
                        activation_scheme=quantization_config.activation_scheme,
                        block_size=quantization_config.weight_block_size,
                    ),
                )
                has_been_replaced = True
        # when changing a layer the TP PLAN for that layer should be updated. TODO

    return model, has_been_replaced


def replace_with_fp8_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
    pre_quantized=False,
):
    """Helper function to replace model layers with FP8 versions.

    Every submodule returned by ``model.named_modules()`` that passes
    ``should_convert_module`` is considered: expert modules are replaced by
    ``FP8Expert`` and ``nn.Linear`` modules by ``FP8Linear``. A warning is logged
    when nothing was replaced.

    Args:
        model: Module modified in place; ``model.config`` is forwarded to
            ``FP8Expert``.
        modules_to_not_convert: Optional name patterns to leave untouched. The
            list passed by the caller is not modified; ``"lm_head"`` and
            ``quantization_config.modules_to_not_convert`` are always added to
            the effective exclusion list.
        quantization_config: Provides ``modules_to_not_convert``,
            ``weight_block_size`` and ``activation_scheme``.
        pre_quantized: When ``False``, the new modules are built with
            ``dtype=None`` instead of their FP8 default.

    Returns:
        The same ``model`` object.
    """
    # Build a fresh exclusion list instead of `+=`/`extend` on the argument,
    # which would silently mutate the caller's list.
    excluded = set(modules_to_not_convert or [])
    excluded.add("lm_head")
    if quantization_config.modules_to_not_convert is not None:
        excluded.update(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(excluded)

    # we need this to correctly materialize the weights during quantization
    # (loop-invariant, so computed once)
    module_kwargs = {} if pre_quantized else {"dtype": None}

    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not should_convert_module(module_name, modules_to_not_convert):
            continue

        # Explicit grouping: without the parentheses `and` binds tighter than
        # `or`, so any name containing "gate_up_proj" was treated as an expert
        # even outside an "experts" container.
        is_expert = "experts" in module_name and (
            "gate_up_proj" in module_name or "down_proj" in module_name
        )

        new_module = None
        with init_empty_weights():
            if is_expert:
                new_module = FP8Expert(
                    config=model.config,
                    block_size=quantization_config.weight_block_size,
                    **module_kwargs,
                )
            elif isinstance(module, nn.Linear):
                new_module = FP8Linear(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    bias=module.bias is not None,
                    activation_scheme=quantization_config.activation_scheme,
                    block_size=quantization_config.weight_block_size,
                    **module_kwargs,
                )
        if new_module is not None:
            model.set_submodule(module_name, new_module)
            has_been_replaced = True

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using fp8 but no linear modules were found in your model."
            " Please double check your model architecture."
        )

    return model


Expand Down
78 changes: 13 additions & 65 deletions src/transformers/integrations/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
if is_accelerate_available():
from accelerate import init_empty_weights

import re
from contextlib import contextmanager

from ..quantizers.quantizers_utils import get_module_from_name
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -436,15 +435,6 @@ def mlp_forward(self, hidden_states):
return routed_out, router_logits


def should_convert_module(current_key_name, patterns):
    """Tell whether the module identified by ``current_key_name`` may be converted.

    Args:
        current_key_name: Either the dotted module name as a ``str`` or the
            sequence of its path components (joined with ``"."``).
        patterns: Regular expressions naming modules to leave untouched. Each
            one is matched with ``re.match``, i.e. anchored at the start of the
            dotted name only. ``None`` or an empty collection excludes nothing.

    Returns:
        ``False`` when any pattern matches the beginning of the dotted name,
        ``True`` otherwise.
    """
    if not patterns:
        # Nothing to exclude (also covers `patterns=None`).
        return True

    # Joining a plain string would insert a dot between every character, so a
    # ready-made dotted name is used as is.
    if isinstance(current_key_name, str):
        full_name = current_key_name
    else:
        full_name = ".".join(current_key_name)

    # A start-anchored match of `key\.` implies a match of `key`, so a single
    # prefix match per pattern is sufficient.
    return not any(re.match(key, full_name) for key in patterns)


def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
from ..integrations.tensor_parallel import shard_and_distribute_module

Expand Down Expand Up @@ -604,70 +594,28 @@ def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton
)


def _replace_with_mxfp4_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
config=None,
):
if current_key_name is None:
current_key_name = []
def replace_with_mxfp4_linear(model, modules_to_not_convert=None, quantization_config=None):
if quantization_config.dequantize:
return model

from kernels import get_kernel

global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

for name, module in model.named_children():
current_key_name.append(name)
if not should_convert_module(current_key_name, modules_to_not_convert):
current_key_name.pop(-1)
has_been_replaced = False
for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue
if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
with init_empty_weights():
model._modules[name] = Mxfp4GptOssExperts(config)
model.set_submodule(module_name, Mxfp4GptOssExperts(model.config))
has_been_replaced = True
if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
from types import MethodType

module.forward = MethodType(mlp_forward, module)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_mxfp4_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
config=config,
)
current_key_name.pop(-1)
return model, has_been_replaced


def replace_with_mxfp4_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
config=None,
):
if quantization_config.dequantize:
return model
else:
from kernels import get_kernel

global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_mxfp4_linear(
model,
modules_to_not_convert,
current_key_name,
quantization_config,
config=config,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
Expand Down
Loading