diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 4505f2a84376..f0e9dffd93a8 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -847,6 +847,7 @@ def convert_and_load_state_dict_in_model( if hf_quantizer and hf_quantizer.pre_quantized and original_key != renamed_key: # if the key was renamed as it is not available in the state dict otherwise, it means that we are deserializing it, # so we need to make sure to load the tensor with the same dtype from the checkpoint + # TODO: make the condition stricter for native fp8 models such as qwen2moe fp8 _dtype = None elif dtype_plan != {} and dtype_policy_alt.search(renamed_key): matched_dtype_pattern = dtype_policy_alt.search(renamed_key) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index d35aa9919dfb..9c3b5af4bf33 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from collections.abc import Sequence from typing import Any from ..core_model_loading import ConversionOps +from ..quantizers.quantizers_utils import should_convert_module from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging @@ -307,37 +307,23 @@ def w8a8_block_fp8_matmul_compile( class FP8Linear(nn.Linear): - dtype = torch.float8_e4m3fn - def __init__( self, in_features: int, out_features: int, bias: bool = False, - dtype=None, + dtype=torch.float8_e4m3fn, block_size: tuple[int, int] | None = None, - device=None, activation_scheme="dynamic", ): super().__init__(in_features, out_features) - self.in_features = in_features - self.out_features = out_features - - self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device)) - - if self.weight.element_size() == 1: - scale_out_features = (out_features + block_size[0] - 1) // block_size[0] - scale_in_features = (in_features + block_size[1] - 1) // block_size[1] - self.weight_scale_inv = nn.Parameter( - torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device) - ) - else: - self.register_parameter("weight_scale_inv", None) - self.block_size = block_size - self.activation_scheme = activation_scheme + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + scale_out_features = (out_features + block_size[0] - 1) // block_size[0] + scale_in_features = (in_features + block_size[1] - 1) // block_size[1] + self.weight_scale_inv = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) if bias: self.bias = nn.Parameter(torch.empty(self.out_features)) else: @@ -380,9 +366,7 @@ def _ceil_div(a, b): class FP8Expert(nn.Module): - dtype = torch.float8_e4m3fn - - def __init__(self, config, block_size, device): + def __init__(self, config, block_size, dtype=torch.float8_e4m3fn): super().__init__() from ..activations import ACT2FN @@ -395,34 +379,24 @@ def __init__(self, config, block_size, device): Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim - self.gate_up_proj = nn.Parameter( - torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device) - ) - self.down_proj = nn.Parameter( - torch.zeros(self.num_experts, Wd_out, Wd_in, 
dtype=FP8Expert.dtype, device=device) - ) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=dtype)) + self.down_proj = nn.Parameter(torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=dtype)) - # Create inverse scale tiles only when using 1-byte types (fp8) - if self.gate_up_proj.element_size() == 1: - bo, bi = self.block_size + bo, bi = self.block_size - # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) - gu_scale_o = _ceil_div(Wg_out, bo) - gu_scale_i = _ceil_div(Wg_in, bi) - self.gate_up_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device) - ) + # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) + gu_scale_o = _ceil_div(Wg_out, bo) + gu_scale_i = _ceil_div(Wg_in, bi) + self.gate_up_proj_scale_inv = nn.Parameter( + torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32) + ) - # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) - dp_scale_o = _ceil_div(Wd_out, bo) - dp_scale_i = _ceil_div(Wd_in, bi) - self.down_proj_scale_inv = nn.Parameter( - torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device) - ) - else: - # Match FP8Linear behavior when not using 1-byte weights - self.register_parameter("gate_up_proj_scale_inv", None) - self.register_parameter("down_proj_scale_inv", None) + # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) + dp_scale_o = _ceil_div(Wd_out, bo) + dp_scale_i = _ceil_div(Wd_in, bi) + self.down_proj_scale_inv = nn.Parameter( + torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32) + ) # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default self.register_parameter("gate_up_bias", None) @@ -488,87 +462,43 @@ def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: to return output.to(dtype=input.dtype) -# TODO: we do need this.... but not recursive... -def _replace_with_fp8_linear( - model, - tp_plan=None, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, -): - iterator = list(model.named_parameters()).copy() - for name, empty_tensor in iterator: - current_key_name = name - name = name.rsplit(".", 1)[0] if "." in name else name - module = model.get_submodule(name) - - current_key_name_str = re.sub(r"\d+", "*", current_key_name) - if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): - with init_empty_weights(): - if ( - "gate_up_proj" in current_key_name - or "down_proj" in current_key_name - and "experts" in current_key_name - ): # Experts! - in_features = empty_tensor.size(-2) - out_features = empty_tensor.size(-1) - model.set_submodule( - name, - FP8Expert( - config=model.config, - block_size=quantization_config.weight_block_size, - device=empty_tensor.device, - ), - ) - - elif isinstance(module, nn.Linear): - in_features = module.in_features - out_features = module.out_features - model.set_submodule( - name, - FP8Linear( - in_features=in_features, - out_features=out_features, - bias=module.bias is not None, - device=module.weight.device, - dtype=module.weight.dtype, - activation_scheme=quantization_config.activation_scheme, - block_size=quantization_config.weight_block_size, - ), - ) - has_been_replaced = True - # when changing a layer the TP PLAN for that layer should be updated. 
TODO - - return model, has_been_replaced - - def replace_with_fp8_linear( model, modules_to_not_convert=None, quantization_config=None, + pre_quantized=False, ): """Helper function to replace model layers with FP8 versions.""" - if modules_to_not_convert is None: - modules_to_not_convert = [] - modules_to_not_convert += ["lm_head"] - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_fp8_linear( - model, - tp_plan=model._tp_plan, - modules_to_not_convert=modules_to_not_convert, - quantization_config=quantization_config, - ) + has_been_replaced = False + for module_name, module in model.named_modules(): + if not should_convert_module(module_name, modules_to_not_convert): + continue + # we need this to correctly materialize the weights during quantization + module_kwargs = {} if pre_quantized else {"dtype": None} + new_module = None + with init_empty_weights(): + if "gate_up_proj" in module_name or "down_proj" in module_name and "experts" in module_name: + new_module = FP8Expert( + config=model.config, block_size=quantization_config.weight_block_size, **module_kwargs + ) + elif isinstance(module, nn.Linear): + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + **module_kwargs, + ) + if new_module is not None: + model.set_submodule(module_name, new_module) + has_been_replaced = True if not has_been_replaced: logger.warning( "You are loading your model using fp8 but no linear modules were found in your model." " Please double check your model architecture." 
) - return model diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 6430c0d9d57d..4f4fee40bf96 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -26,10 +26,9 @@ if is_accelerate_available(): from accelerate import init_empty_weights -import re from contextlib import contextmanager -from ..quantizers.quantizers_utils import get_module_from_name +from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module logger = logging.get_logger(__name__) @@ -436,15 +435,6 @@ def mlp_forward(self, hidden_states): return routed_out, router_logits -def should_convert_module(current_key_name, patterns): - current_key_name_str = ".".join(current_key_name) - if not any( - re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns - ): - return True - return False - - def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs): from ..integrations.tensor_parallel import shard_and_distribute_module @@ -604,70 +594,28 @@ def swizzle_mxfp4_convertops(blocks, scales, module, proj, target_device, triton ) -def _replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, - config=None, -): - if current_key_name is None: - current_key_name = [] +def replace_with_mxfp4_linear(model, modules_to_not_convert=None, quantization_config=None): + if quantization_config.dequantize: + return model + + from kernels import get_kernel + + global triton_kernels_hub + triton_kernels_hub = get_kernel("kernels-community/triton_kernels") - for name, module in model.named_children(): - current_key_name.append(name) - if not should_convert_module(current_key_name, modules_to_not_convert): - current_key_name.pop(-1) + has_been_replaced = False + for module_name, module in model.named_modules(): + if not should_convert_module(module_name, modules_to_not_convert): continue if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize: with init_empty_weights(): - model._modules[name] = Mxfp4GptOssExperts(config) + model.set_submodule(module_name, Mxfp4GptOssExperts(model.config)) has_been_replaced = True if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize: from types import MethodType module.forward = MethodType(mlp_forward, module) - if len(list(module.children())) > 0: - _, has_been_replaced = _replace_with_mxfp4_linear( - module, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - config=config, - ) - current_key_name.pop(-1) - return model, has_been_replaced - -def replace_with_mxfp4_linear( - model, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - config=None, -): - if quantization_config.dequantize: - return model - else: - from kernels import get_kernel - - global triton_kernels_hub - triton_kernels_hub = get_kernel("kernels-community/triton_kernels") - - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert - - if quantization_config.modules_to_not_convert is not None: - modules_to_not_convert.extend(quantization_config.modules_to_not_convert) - modules_to_not_convert = list(set(modules_to_not_convert)) - model, has_been_replaced = _replace_with_mxfp4_linear( - model, - modules_to_not_convert, - current_key_name, - quantization_config, - config=config, - ) 
if not has_been_replaced: logger.warning( "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model." diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index b0a5873da303..d8f5609a36f4 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from copy import deepcopy from typing import TYPE_CHECKING, Any -from ..utils import is_accelerate_available, is_torch_available, logging +from ..utils import is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod from .quantizers_utils import get_module_from_name -if is_accelerate_available(): - from accelerate.utils import find_tied_parameters - if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel @@ -47,48 +43,28 @@ def _assign_original_dtype(module, original_dtype): def get_keys_to_not_convert(model): r""" - An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want - to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in - int8. - - Parameters: - model (`torch.nn.Module`): - Input model + Function to automatically detect keys to not convert for usage like quantization. For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. """ - # Create a copy of the model and tie the weights, then - # check if it contains tied weights - tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` - tied_model.tie_weights() - - tied_params = find_tied_parameters(tied_model) - tied_keys = sum(tied_params, []) - has_tied_params = len(tied_keys) > 0 - - # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision - if not has_tied_params: - output_emb = model.get_output_embeddings() - if output_emb is not None: - list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] - return list_last_module - - # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision - list_modules = list(model.named_parameters()) - list_last_module = [list_modules[-1][0]] - # add last module together with tied weights - intersection = set(list_last_module) - set(tied_keys) - list_untouched = list(set(tied_keys)) + list(intersection) - - # remove ".weight" from the keys - names_to_remove = [".weight", ".bias"] - filtered_module_names = [] - for name in list_untouched: - for name_to_remove in names_to_remove: - if name_to_remove in name: - name = name.replace(name_to_remove, "") - filtered_module_names.append(name) - - return filtered_module_names + # remove tied weights + tied_keys = set() + if len(model.all_tied_weights_keys) > 0: + tied_keys = set(model.all_tied_weights_keys.values()) | set(model.all_tied_weights_keys.keys()) + + # remove last module + last_module_key = {list(model.named_parameters())[-1][0]} + + # remove output emb + output_emb_module = model.get_output_embeddings() + output_emb_keys = { + name + for name, module in model.named_modules() + if output_emb_module is 
not None and id(module) == id(output_emb_module) + } + candidates = tied_keys | last_module_key | output_emb_keys + + modules_to_not_convert = {name.replace(suffix, "") for name in candidates for suffix in [".weight", ".bias"]} + return modules_to_not_convert class HfQuantizer(ABC): @@ -360,6 +336,8 @@ def get_modules_to_not_convert( if keep_in_fp32_modules is not None: modules_to_not_convert.extend(keep_in_fp32_modules) + modules_to_not_convert = list(set(modules_to_not_convert)) + return modules_to_not_convert @property diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 293dede7dc52..190263b292d1 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -158,16 +158,15 @@ def _process_model_before_weight_loading( ): from ..integrations.finegrained_fp8 import replace_with_fp8_linear - # takes 2 fucking seconds self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) - # while this one is 81ms :) model = replace_with_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, ) model.config.quantization_config = self.quantization_config diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index 89926e4f6339..e5e70af1ab6b 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -301,6 +301,7 @@ def _process_model_before_weight_loading( self, model: "PreTrainedModel", keep_in_fp32_modules: list[str] | None = None, + use_kernels: bool = False, **kwargs, ): from ..integrations import replace_with_mxfp4_linear @@ -309,7 +310,6 @@ def _process_model_before_weight_loading( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) - use_kernels = kwargs.get("use_kernels", False) # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling if use_kernels: logger.warning_once( @@ -318,12 +318,8 @@ def _process_model_before_weight_loading( ) self.quantization_config.dequantize = True - config = model.config model = replace_with_mxfp4_linear( - model, - modules_to_not_convert=self.modules_to_not_convert, - quantization_config=self.quantization_config, - config=config, + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config ) model.config.quantization_config = self.quantization_config diff --git a/src/transformers/quantizers/quantizers_utils.py b/src/transformers/quantizers/quantizers_utils.py index 94955495f561..5e3429b602f8 100644 --- a/src/transformers/quantizers/quantizers_utils.py +++ b/src/transformers/quantizers/quantizers_utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any +import re +from typing import Any, Optional def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: @@ -19,3 +20,22 @@ def get_module_from_name(module, tensor_name: str) -> tuple[Any, str]: module_name, tensor_name = tensor_name.rsplit(".", 1) module = module.get_submodule(module_name) return module, tensor_name + + +def should_convert_module(full_name, patterns: Optional[list] = None): + if patterns is None: + return True + + # We should avoid converting in the following situations: + # 1. The pattern appears as a prefix followed by a dot in `full_name` + # (e.g., "model.decoder.layer.11." matches "model.decoder.layer.11.attn.weight"). + # 2. The pattern matches `full_name` exactly or via regex + # (e.g., "lm_head" matches "lm_head"; "model.decoder.layer.*" matches "model.decoder.layer.11.attn.weight"). + # 3. `full_name` ends with the pattern + # (e.g., "fc1" matches "model.decoder.layers.23.fc1"). + + should_not_convert = any( + re.match(f"{key}\\.", full_name) or re.match(f"{key}", full_name) or full_name.endswith(key) + for key in patterns + ) + return not should_not_convert diff --git a/tests/quantization/finegrained_fp8/test_fp8.py b/tests/quantization/finegrained_fp8/test_fp8.py index b05ee61dcb38..2d9137486bcc 100644 --- a/tests/quantization/finegrained_fp8/test_fp8.py +++ b/tests/quantization/finegrained_fp8/test_fp8.py @@ -73,7 +73,11 @@ class FP8QuantizerTest(unittest.TestCase): model_name = "meta-llama/Llama-3.2-1B" input_text = "Once upon a time" max_new_tokens = 10 - EXPECTED_OUTPUT = "Once upon a time, there was a man who was very rich." + EXPECTED_OUTPUTS = { + "Once upon a time, there was a little girl who loved to play", + "Once upon a time, there was a man who was very rich.", + } + device_map = torch_device offload_device_map = { "model.embed_tokens": 0, @@ -132,25 +136,21 @@ def test_quantized_model_conversion(self): for module in model.modules(): if isinstance(module, torch.nn.Linear): nb_linears += 1 - model = replace_with_fp8_linear(model, quantization_config=quantization_config) nb_fp8_linear = 0 for module in model.modules(): if isinstance(module, FP8Linear): nb_fp8_linear += 1 - print(model) - self.assertEqual(nb_linears - 1, nb_fp8_linear) - + self.assertEqual(nb_linears, nb_fp8_linear) with init_empty_weights(): model = OPTForCausalLM(config) - quantization_config = FineGrainedFP8Config(modules_to_not_convert=["fc1"]) - model = replace_with_fp8_linear(model, quantization_config=quantization_config) + quantization_config = FineGrainedFP8Config() + model = replace_with_fp8_linear(model, modules_to_not_convert=["fc1"], quantization_config=quantization_config) nb_fp8_linear = 0 for module in model.modules(): if isinstance(module, FP8Linear): nb_fp8_linear += 1 - - self.assertEqual(nb_linears - 25, nb_fp8_linear) + self.assertEqual(nb_linears - 24, nb_fp8_linear) def test_quantized_model(self): """ @@ -160,7 +160,7 @@ def test_quantized_model(self): output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) output_tokens = self.tokenizer.decode(output[0], skip_special_tokens=True) - self.assertEqual(output_tokens, self.EXPECTED_OUTPUT) + self.assertIn(output_tokens, self.EXPECTED_OUTPUTS) def test_save_pretrained(self): """ @@ -174,7 +174,7 @@ def test_save_pretrained(self): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map) output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - 
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) def test_weight_and_weight_scale_inv(self): """ @@ -209,11 +209,10 @@ def test_quantized_model_multi_accelerator(self): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, device_map="auto", quantization_config=quantization_config ) - print("hf_device_map", quantized_model.hf_device_map) self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_multi_accelerator def test_save_pretrained_multi_accelerators(self): @@ -229,7 +228,7 @@ def test_save_pretrained_multi_accelerators(self): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device_map) output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) def test_quantized_model_offload(self): """ @@ -253,7 +252,7 @@ def test_save_pretrained_offload(self): quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map) output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) @require_torch_accelerator @@ -271,7 +270,7 @@ def test_linear_preserves_shape(self): """ from transformers.integrations import FP8Linear - linear = FP8Linear(256, 256, block_size=(128, 128), device=self.device) + linear = FP8Linear(256, 256, block_size=(128, 128)).to(self.device) x = torch.rand((1, 5, 256)).to(self.device) x_ = linear(x) @@ -283,7 +282,7 @@ def test_linear_with_diff_feature_size_preserves_shape(self): """ from transformers.integrations import FP8Linear - linear = FP8Linear(128, 256, block_size=(128, 128), device=self.device) + linear = FP8Linear(128, 256, block_size=(128, 128)).to(self.device) x = torch.rand((1, 5, 128)).to(self.device) x_ = linear(x)
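Illustration of the skip-list matching that the new `should_convert_module` helper centralizes (both `replace_with_fp8_linear` and `replace_with_mxfp4_linear` now call it). A minimal sketch, assuming the helper is importable from `transformers.quantizers.quantizers_utils` as added in this diff; the module names are hypothetical:

```python
# Sketch: behavior of the new should_convert_module helper. It returns False
# (skip conversion) when any pattern matches as a prefix followed by a dot,
# as a regex/exact match anchored at the start, or as a suffix of the name.
from transformers.quantizers.quantizers_utils import should_convert_module

assert should_convert_module("model.decoder.layers.0.self_attn.q_proj", ["lm_head"])         # no match -> convert
assert not should_convert_module("lm_head", ["lm_head"])                                      # exact match
assert not should_convert_module("model.decoder.layers.23.fc1", ["fc1"])                      # suffix match
assert not should_convert_module("model.decoder.layers.11.fc1", ["model.decoder.layers.*"])   # regex prefix
assert should_convert_module("anything", None)                                                # no patterns -> always convert
```

Note that `re.match` anchors at the start, so a bare pattern behaves as a prefix match rather than a strict equality test.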
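The rewritten `FP8Linear.__init__` always allocates fp8 weights together with block-wise inverse scales, whose shape is the ceil-division of the weight shape by the block size. A shape-only sketch with arbitrary toy sizes; no forward pass is run, so no accelerator or triton kernels are needed:

```python
# Sketch: per-block inverse scale shape is (ceil(out/block_out), ceil(in/block_in)).
import torch

from transformers.integrations.finegrained_fp8 import FP8Linear

linear = FP8Linear(in_features=512, out_features=256, block_size=(128, 128))
assert linear.weight.dtype == torch.float8_e4m3fn       # default dtype after this change
assert tuple(linear.weight.shape) == (256, 512)
assert tuple(linear.weight_scale_inv.shape) == (2, 4)   # ceil(256/128) x ceil(512/128)
assert linear.weight_scale_inv.dtype == torch.float32
```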
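The `pre_quantized` flag threaded from `FineGrainedFP8HfQuantizer` into `replace_with_fp8_linear` controls whether replacement modules keep the fp8 default dtype (deserializing an already-quantized checkpoint) or receive `dtype=None` so that weights materialize in the default dtype before on-the-fly quantization. A hedged end-to-end sketch on a toy OPT configuration (the sizes below are made up for illustration; requires `accelerate` and a torch build with `float8_e4m3fn` support):

```python
# Sketch: module replacement with the new pre_quantized flag, on empty (meta) weights.
import torch
from accelerate import init_empty_weights

from transformers import FineGrainedFP8Config, OPTConfig, OPTForCausalLM
from transformers.integrations.finegrained_fp8 import FP8Linear, replace_with_fp8_linear

config = OPTConfig(hidden_size=128, ffn_dim=256, num_hidden_layers=2, num_attention_heads=4, vocab_size=512)
with init_empty_weights():
    model = OPTForCausalLM(config)

# pre_quantized=True keeps the fp8 default dtype so checkpoint tensors deserialize as-is;
# pre_quantized=False passes dtype=None so weights materialize in the default dtype first.
model = replace_with_fp8_linear(
    model,
    modules_to_not_convert=["lm_head"],
    quantization_config=FineGrainedFP8Config(),
    pre_quantized=True,
)

fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1, FP8Linear)
assert fc1.weight.dtype == torch.float8_e4m3fn
assert not isinstance(model.lm_head, FP8Linear)  # skipped via modules_to_not_convert
```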