From 7465d6090704815c6f1eabf4b3cb3857ad36a486 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 18 Nov 2025 17:36:44 +0000
Subject: [PATCH] wrap better

Signed-off-by: Kyle Sayers
---
 .../quantization/lifecycle/forward.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index 573826e1..ca060cf9 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -363,28 +363,28 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
-        if not getattr(module, "quantization_enabled", True):
+        if not getattr(self, "quantization_enabled", True):
             # quantization is disabled on forward passes, return baseline
             # forward call
-            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+            return forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
 
         input_ = args[0]
 
-        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
+        compressed = self.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # prehook should calibrate activations before forward call
-            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
+            input_ = forward_quantize(self, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = forward_quantize(
-                module, self.weight, "weight", scheme.weights
+                self, self.weight, "weight", scheme.weights
             )
 
         # perform wrapped forward call
-        output = forward_func_orig.__get__(module, module.__class__)(
+        output = forward_func_orig.__get__(self, self.__class__)(
             input_, *args[1:], **kwargs
         )
 
@@ -395,14 +395,12 @@ def wrapped_forward(self, *args, **kwargs):
         if scheme.output_activations is not None:
             # forward-hook should calibrate/forward_quantize
             if (
-                module.quantization_status == QuantizationStatus.CALIBRATION
+                self.quantization_status == QuantizationStatus.CALIBRATION
                 and not scheme.output_activations.dynamic
             ):
                 return output
 
-            output = forward_quantize(
-                module, output, "output", scheme.output_activations
-            )
+            output = forward_quantize(self, output, "output", scheme.output_activations)
         return output
 
     # bind wrapped forward to module class so reference to `self` is correct
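
Note (not part of the patch): the diff replaces every reference to the closed-over `module` inside `wrapped_forward` with `self`, the instance the wrapper is bound to via `__get__` when it is attached as `forward`. The sketch below shows that descriptor-binding pattern in isolation; the `Toy` class and `wrap_forward` helper are hypothetical stand-ins, and the deepcopy scenario is only one illustration of how `self` and the captured `module` can end up pointing at different instances.

import copy
from functools import wraps


class Toy:
    """Hypothetical stand-in for a module whose forward gets wrapped."""

    def __init__(self, name):
        self.name = name

    def forward(self, x):
        return f"{self.name}: {x}"


def wrap_forward(module):
    # grab the underlying function, mirroring module.forward.__func__ in the patch
    forward_func_orig = module.forward.__func__

    @wraps(forward_func_orig)
    def wrapped_forward(self, *args, **kwargs):
        # use `self` (whichever instance the wrapper is bound to at call time),
        # not the `module` captured by the closure
        out = forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
        return out + " [quantization hooks would run here]"

    # bind the wrapper back onto the instance via the descriptor protocol
    setattr(module, "forward", wrapped_forward.__get__(module, module.__class__))


m = Toy("original")
wrap_forward(m)

m2 = copy.deepcopy(m)   # deepcopy rebinds the stored bound method to m2
m2.name = "copy"

print(m.forward("x"))   # original: x [quantization hooks would run here]
print(m2.forward("x"))  # copy: x ... ; with the closed-over `module`, the
                        # original instance would be used instead

Because plain functions are descriptors, `wrapped_forward.__get__(instance, cls)` yields a bound method, so `self` always tracks whichever instance the wrapper is actually attached to.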