18 changes: 8 additions & 10 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -363,28 +363,28 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
-        if not getattr(module, "quantization_enabled", True):
+        if not getattr(self, "quantization_enabled", True):
             # quantization is disabled on forward passes, return baseline
             # forward call
-            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+            return forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
 
         input_ = args[0]
 
-        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
+        compressed = self.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # prehook should calibrate activations before forward call
-            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
+            input_ = forward_quantize(self, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = forward_quantize(
-                module, self.weight, "weight", scheme.weights
+                self, self.weight, "weight", scheme.weights
             )
 
         # perform wrapped forward call
-        output = forward_func_orig.__get__(module, module.__class__)(
+        output = forward_func_orig.__get__(self, self.__class__)(
             input_, *args[1:], **kwargs
         )
 
@@ -395,14 +395,12 @@ def wrapped_forward(self, *args, **kwargs):
         if scheme.output_activations is not None:
             # forward-hook should calibrate/forward_quantize
             if (
-                module.quantization_status == QuantizationStatus.CALIBRATION
+                self.quantization_status == QuantizationStatus.CALIBRATION
                 and not scheme.output_activations.dynamic
             ):
                 return output
 
-            output = forward_quantize(
-                module, output, "output", scheme.output_activations
-            )
+            output = forward_quantize(self, output, "output", scheme.output_activations)
             return output
 
     # bind wrapped forward to module class so reference to `self` is correct
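Why the change matters: `wrapped_forward` is ultimately bound as a method, so `self` always resolves to the instance actually being called, while the closed-over `module` variable is frozen at wrap time. The two can diverge, for example across `copy.deepcopy`, which rebinds the method's `self` but not its closure. A minimal sketch of that failure mode (the `wrap_forward` helper and prints are hypothetical, not the library's real code):

```python
# Minimal sketch, assuming a simplified wrapper; `wrap_forward` is
# illustrative only, not the library's actual API.
import copy

import torch
from torch.nn import Linear, Module


def wrap_forward(module: Module) -> None:
    forward_func_orig = module.forward.__func__  # plain function off the class

    def wrapped_forward(self, *args, **kwargs):
        # `module` is captured in the closure at wrap time; `self` is whatever
        # instance this bound method currently belongs to.
        print("closure `module` is `self`:", module is self)
        return forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)

    # bind the wrapper to the instance, mirroring the diff's closing comment
    setattr(module, "forward", wrapped_forward.__get__(module, module.__class__))


layer = Linear(4, 4)
wrap_forward(layer)
layer(torch.randn(1, 4))      # closure `module` is `self`: True

clone = copy.deepcopy(layer)  # deepcopy rebinds the method's `self`,
clone(torch.randn(1, 4))      # but the closure still holds `layer`: False
```

On the copy, any state lookup through the closure (`quantization_enabled`, `quantization_status`, observers) would silently read the original module; routing everything through `self`, as this diff does, keeps the lookups on the instance being executed.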