14 | 14 | # limitations under the License.
15 | 15 |
16 | 16 | import re
17 | | -from collections.abc import Sequence
18 | | -from typing import Any
| 17 | +from typing import Optional
19 | 18 |
20 | 19 | from ..core_model_loading import ConversionOps
21 | 20 | from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@@ -549,6 +548,9 @@ def replace_with_fp8_linear(
549 | 548 |     quantization_config=None,
550 | 549 | ):
551 | 550 |     """Helper function to replace model layers with FP8 versions."""
| 551 | +    if quantization_config.dequantize:
| 552 | +        return model
| 553 | +
552 | 554 |     if modules_to_not_convert is None:
553 | 555 |         modules_to_not_convert = []
554 | 556 |     modules_to_not_convert += ["lm_head"]
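The new early return means that when `quantization_config.dequantize` is set, `replace_with_fp8_linear` leaves the module tree untouched: no `FP8Linear` layers are created, and the checkpoint's FP8 weights are instead dequantized while loading. A hedged usage sketch of that path, assuming the flag is exposed as a `dequantize` kwarg on `FineGrainedFP8Config` (the checkpoint name is a placeholder, not a real repo):

```python
# Minimal sketch, not a verified snippet: assumes FineGrainedFP8Config accepts a
# `dequantize` argument as introduced by this change; the model id is made up.
import torch
from transformers import AutoModelForCausalLM, FineGrainedFP8Config

config = FineGrainedFP8Config(dequantize=True)  # assumed kwarg; skips FP8Linear replacement
model = AutoModelForCausalLM.from_pretrained(
    "org/some-fp8-checkpoint",      # placeholder FP8-quantized repo
    quantization_config=config,
    torch_dtype=torch.bfloat16,
)
# With dequantize=True the model keeps regular nn.Linear modules, whose weights are
# reconstructed from the stored (weight, weight_scale_inv) pairs at load time.
```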
@@ -652,41 +654,45 @@ def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]
652 | 654 | class Fp8Dequantize(ConversionOps):
653 | 655 |     """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
654 | 656 |
655 | | -    def __init__(self, block_size: tuple[int, int] | None = None):
656 | | -        self.block_size = block_size
| 657 | +    def __init__(self, hf_quantizer):
| 658 | +        self.hf_quantizer = hf_quantizer
657 | 659 |
658 | 660 |     def convert(
659 | 661 |         self,
660 | | -        value: Sequence[torch.Tensor] | dict[str, torch.Tensor],
661 | | -        *,
662 | | -        context: dict[str, Any],
663 | | -    ) -> torch.Tensor:
664 | | -        if isinstance(value, dict):
665 | | -            tensors = list(value.values())
666 | | -        else:
667 | | -            tensors = list(value) if isinstance(value, Sequence) else [value]
668 | | -        if len(tensors) != 2:
669 | | -            raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
670 | | -        quantized, scales = tensors
671 | | -        if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
672 | | -            raise TypeError("Fp8Dequantize expects tensors as inputs.")
673 | | -
674 | | -        quantized_fp32 = quantized.to(torch.float32)
675 | | -        rows, cols = quantized_fp32.shape[-2:]
676 | | -        block_size = self.block_size
677 | | -        if block_size is None:
678 | | -            quant_config = context.get("quantization_config")
679 | | -            block_size = getattr(quant_config, "weight_block_size", None)
680 | | -        if block_size is None:
681 | | -            block_size = (rows, cols)
| 662 | +        input_dict: dict[str, torch.Tensor],
| 663 | +        model: Optional[torch.nn.Module] = None,
| 664 | +        full_layer_name: str | None = None,
| 665 | +        missing_keys=None,
| 666 | +        **kwargs,
| 667 | +    ) -> dict[str, torch.Tensor]:
| 668 | +        if len(input_dict) != 2:
| 669 | +            # in case of no scales, the weights are not quantized, so we return the weights as is
| 670 | +            return {
| 671 | +                full_layer_name: input_dict["weight$"][0]
| 672 | +                if isinstance(input_dict["weight$"], list)
| 673 | +                else input_dict["weight$"]
| 674 | +            }
| 675 | +        quantized = input_dict["weight$"][0] if isinstance(input_dict["weight$"], list) else input_dict["weight$"]
| 676 | +        scales = (
| 677 | +            input_dict["weight_scale_inv"][0]
| 678 | +            if isinstance(input_dict["weight_scale_inv"], list)
| 679 | +            else input_dict["weight_scale_inv"]
| 680 | +        )
| 681 | +
| 682 | +        rows, cols = quantized.shape[-2:]
| 683 | +        block_size = self.hf_quantizer.quantization_config.weight_block_size
| 684 | +
682 | 685 |         block_m, block_n = block_size
683 | 686 |         if rows % block_m != 0 or cols % block_n != 0:
684 | 687 |             raise ValueError(
685 | 688 |                 f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
686 | 689 |             )
687 | 690 |
688 | | -        reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
| 691 | +        reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
689 | 692 |         expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
690 | 693 |         expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
691 | 694 |         dequantized = reshaped * expanded_scales
692 | | -        return dequantized.reshape(quantized_fp32.shape)
| 695 | +
| 696 | +        return {
| 697 | +            full_layer_name: dequantized.reshape(quantized.shape),
| 698 | +        }
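Independently of the loading machinery, the core of `Fp8Dequantize.convert` is a block-wise rescale: the weight matrix is split into `(block_m, block_n)` tiles and each tile is multiplied by its entry from `weight_scale_inv`. A standalone sketch of that math, using ordinary float32 tensors in place of the FP8 storage dtype and made-up shapes and block size:

```python
# Illustration only: plain float32 tensors stand in for the FP8 weights, and the
# shapes/block size below are assumptions, not values from a real checkpoint.
import torch

rows, cols = 256, 512
block_m, block_n = 128, 128                              # weight_block_size from the quantization config
quantized = torch.randn(rows, cols)                      # would be torch.float8_e4m3fn in the real path
scales = torch.rand(rows // block_m, cols // block_n)    # one scale per (block_m x block_n) tile

# Reshape into (batch, tiles_m, block_m, tiles_n, block_n) so each tile can be scaled independently.
reshaped = quantized.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
expanded = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
expanded = expanded.unsqueeze(-1).unsqueeze(2)           # -> (1, tiles_m, 1, tiles_n, 1) for broadcasting
dequantized = (reshaped * expanded).reshape(quantized.shape)
print(dequantized.shape)                                 # torch.Size([256, 512])
```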