From 27e22bdb8428ed3457cd2b61333c54b1bacf2822 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 28 Jan 2026 21:30:03 +0000 Subject: [PATCH 1/5] Update on the quantlinear & dynamic module Signed-off-by: Jingyu Xin --- modelopt/torch/opt/dynamic.py | 12 +++++++++++ .../quantization/nn/modules/quant_module.py | 20 ++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index a2834329e..8950d1e5c 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -584,6 +584,14 @@ def export(self) -> nn.Module: assert not is_dynamic, "Exported module must not be a DynamicModule anymore!" delattr(self, "_dm_attribute_manager") + # If this module had a monkey-patched forward before DynamicModule.convert(), we may have + # overridden it by binding the dynamic forward onto the instance (to follow the MRO). + # On final export, restore the original forward to avoid leaking a dynamic forward + # (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance. + if hasattr(self, "_forward_pre_dm"): + setattr(self, "forward", getattr(self, "_forward_pre_dm")) + delattr(self, "_forward_pre_dm") + return self @classmethod @@ -621,6 +629,10 @@ def bind_forward_method_if_needed(self): # accelerate patched module bind_forward_method(self, self.__class__.forward) else: + if not hasattr(self, "_forward_pre_dm"): + # Keep the patched forward for downstream modules that want to call it. + self._forward_pre_dm = self.forward + bind_forward_method(self, self.__class__.forward) warnings.warn( "Received a module with monkey patched forward method. Dynamic converted module" " might not work." diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 12aaee3f8..e00e7c77d 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -110,7 +110,25 @@ class QuantInputBase(QuantModule): def forward(self, input, *args, **kwargs): """Quantize the input before calling the original forward method.""" input = self.input_quantizer(input) - output = super().forward(input, *args, **kwargs) + if hasattr(self, "_forward_pre_dm"): + pre_fwd = getattr(self, "_forward_pre_dm") + + def _is_forward_in_mro(bound_or_func) -> bool: + # If this is a bound method, compare its underlying function to any `forward` + # implementation in the current MRO. If it matches, it's not an external monkey-patch. 
+ if hasattr(bound_or_func, "__func__"): + fn = bound_or_func.__func__ + for cls in type(self).mro(): + if cls.__dict__.get("forward") is fn: + return True + return False + + if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd): + output = super().forward(input, *args, **kwargs) + else: + output = pre_fwd(input, *args, **kwargs) + else: + output = super().forward(input, *args, **kwargs) if isinstance(output, tuple): return (self.output_quantizer(output[0]), *output[1:]) return self.output_quantizer(output) From 54ef5d8390af1a56d1df6d405fc6d2f6f877ed40 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 29 Jan 2026 23:10:33 +0000 Subject: [PATCH 2/5] Update the test case Signed-off-by: Jingyu Xin --- tests/unit/torch/opt/test_chaining.py | 106 ++++++++++++++++++ .../quantization/test_forward_patching.py | 93 +++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 tests/unit/torch/quantization/test_forward_patching.py diff --git a/tests/unit/torch/opt/test_chaining.py b/tests/unit/torch/opt/test_chaining.py index bedbbfee0..0025f0346 100644 --- a/tests/unit/torch/opt/test_chaining.py +++ b/tests/unit/torch/opt/test_chaining.py @@ -15,6 +15,7 @@ import pytest import torch +import torch.nn.functional as F from _test_utils.torch.misc import compare_outputs from _test_utils.torch.opt.utils import apply_mode_with_sampling from torchvision.models.mobilenetv2 import InvertedResidual @@ -22,10 +23,20 @@ import modelopt.torch.distill as mtd import modelopt.torch.nas as mtn import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq import modelopt.torch.sparsity as mts from modelopt.torch.utils.distributed import _serialize +class SimpleLinearModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x) + + def get_model(): return InvertedResidual(16, 32, 1, 6) @@ -228,3 +239,98 @@ def test_sparse_quantized_module(): model = mtn.export(model) assert torch.equal(conv.weight, weight_expected) assert torch.equal(conv._parameters["weight"], weight_expected), "Weight should be overwritten!" 
+ + +def test_sparse_quantize_kd_linear_forward_backward(): + """Ensure sparse + quantize + distill works for linear forward/backward.""" + model = SimpleLinearModel() + teacher_model = SimpleLinearModel() + + def patched_forward(x): + called["patched_forward"] += 1 + w = model.linear.weight + b = model.linear.bias if model.linear.bias is not None else None + return F.linear(x, w, b) + + model.linear.forward = patched_forward + teacher_model.linear.forward = patched_forward + + def _get_linear_kd_mode(): + config = { + "teacher_model": teacher_model, + "criterion": {("linear", "linear"): mtd.LogitsDistillationLoss()}, + "loss_balancer": mtd.StaticLossBalancer(), + } + return [("kd_loss", config)] + + model = mto.apply_mode(model, mode="sparse_magnitude", init_state=True) + model = mto.apply_mode(model, mode="quantize") + model = mto.apply_mode(model, mode=_get_linear_kd_mode()) + + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0} + + def _count_quant_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_quant_weight(_m, _inp, _out): + called["weight_q"] += 1 + + model.linear.input_quantizer.register_forward_hook(_count_quant_input) + model.linear.weight_quantizer.register_forward_hook(_count_quant_weight) + + model.train() + x = torch.randn(2, 4) + target = torch.randn(2, 4) + output = model(x) + loss = F.mse_loss(output, target) + loss.backward() + + assert output.shape == target.shape + assert any(p.grad is not None for p in model.parameters() if p.requires_grad), ( + "Expected gradients on student parameters." + ) + assert called["patched_forward"] == 2 + assert called["input_q"] == 1 + assert called["weight_q"] == 1 + + +def test_chained_modes_preserve_forward_patching_during_quantize(): + """Ensure chained modes do not break runtime forward patching during quantize.""" + model = InvertedResidual(16, 32, 1, 6).to(torch.float16) + model = mto.apply_mode(model, mode="fastnas", init_state=True) + model = mto.apply_mode(model, mode="export_nas") + + conv = model.conv[0][0] + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0} + + def patched_forward(x): + called["patched_forward"] += 1 + return F.conv2d( + x, + conv.weight, + conv.bias, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + ) + + conv.forward = patched_forward + + def _count_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_weight(_m, _inp, _out): + called["weight_q"] += 1 + + def forward_loop(model): + conv.input_quantizer.register_forward_hook(_count_input) + conv.weight_quantizer.register_forward_hook(_count_weight) + x = torch.randn(1, 16, 8, 8, dtype=torch.float16) + model(x) + + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop) + + assert called["patched_forward"] == 1 + assert called["input_q"] == 1 + assert called["weight_q"] == 1 diff --git a/tests/unit/torch/quantization/test_forward_patching.py b/tests/unit/torch/quantization/test_forward_patching.py new file mode 100644 index 000000000..a28a1bce6 --- /dev/null +++ b/tests/unit/torch/quantization/test_forward_patching.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +import torch +import torch.nn.functional as F +from torch import nn + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization import QuantModuleRegistry +from modelopt.torch.quantization.nn.modules.quant_module import QuantLinearConvBase + + +def test_quant_input_base_ignores_forward_pre_dm_in_mro(): + """Regression test for recursion when `_forward_pre_dm` points to a wrapper forward in the MRO. + + In complex wrapper stacks, `_forward_pre_dm` may accidentally end up referencing a `forward` + method already present in the quant wrapper MRO (e.g. QuantLinearConvBase.forward). If + QuantInputBase.forward calls that directly, it can recurse indefinitely: + + QuantLinearConvBase.forward -> super().forward (QuantInputBase.forward) + -> _forward_pre_dm (QuantLinearConvBase.forward) -> ... + + The fix is to detect this case and fall back to `super().forward` instead. + """ + lin = nn.Linear(8, 8, bias=False) + QuantModuleRegistry.convert(lin) + + # Force the problematic state: `_forward_pre_dm` points to a wrapper forward already in MRO. + lin._forward_pre_dm = types.MethodType(QuantLinearConvBase.forward, lin) + + x = torch.randn(2, 8) + y = lin(x) + assert isinstance(y, torch.Tensor) + assert y.shape == (2, 8) + + +def test_quantize_calibration_calls_quantizers_with_runtime_forward_patch(): + """Regression test for on-the-fly forward patching during mtq.quantize calibration. + + Some frameworks replace `module.forward` on-the-fly with a closure just before a forward pass. + During mtq.quantize calibration, quantizers must still run (input + weight at minimum). + """ + lin = nn.Linear(8, 8, bias=True).to(torch.float32) + + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0} + + # Monkey patch instance-level forward (closure-style, no `self` argument). + def patched_forward(x): + called["patched_forward"] += 1 + # Use module parameters directly; if quantization wrappers are active, weight access + # should still be routed through the quantized path. + w = lin.weight.to(dtype=x.dtype) + b = lin.bias.to(dtype=x.dtype) if lin.bias is not None else None + return F.linear(x, w, b) + + def _count_input(_m, _inp, _out): + called["input_q"] += 1 + + def _count_weight(_m, _inp, _out): + called["weight_q"] += 1 + + lin.forward = patched_forward + x = torch.randn(2, 8, dtype=torch.float16) + + def forward_loop(model): + # Patch forward on-the-fly (after conversion, right before calibration forward). + + # Count quantizer executions during calibration. 
+ model.input_quantizer.register_forward_hook(_count_input) + model.weight_quantizer.register_forward_hook(_count_weight) + + model(x) + + mtq.quantize(lin, mtq.INT8_DEFAULT_CFG, forward_loop) + lin(x) + + assert called["patched_forward"] == 2 + assert called["input_q"] == 2 + assert called["weight_q"] == 2 From 1226cceb1eb9ec97f1d27bc6656f93e9f5c4613b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 29 Jan 2026 23:12:23 +0000 Subject: [PATCH 3/5] Add the MR link Signed-off-by: Jingyu Xin --- modelopt/torch/opt/dynamic.py | 2 ++ modelopt/torch/quantization/nn/modules/quant_module.py | 1 + 2 files changed, 3 insertions(+) diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index 8950d1e5c..2131639be 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -588,6 +588,7 @@ def export(self) -> nn.Module: # overridden it by binding the dynamic forward onto the instance (to follow the MRO). # On final export, restore the original forward to avoid leaking a dynamic forward # (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance. + # please see: https://github.com/NVIDIA/Model-Optimizer/pull/824 if hasattr(self, "_forward_pre_dm"): setattr(self, "forward", getattr(self, "_forward_pre_dm")) delattr(self, "_forward_pre_dm") @@ -629,6 +630,7 @@ def bind_forward_method_if_needed(self): # accelerate patched module bind_forward_method(self, self.__class__.forward) else: + # https://github.com/NVIDIA/Model-Optimizer/pull/824 if not hasattr(self, "_forward_pre_dm"): # Keep the patched forward for downstream modules that want to call it. self._forward_pre_dm = self.forward diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index e00e7c77d..f7bfff243 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -110,6 +110,7 @@ class QuantInputBase(QuantModule): def forward(self, input, *args, **kwargs): """Quantize the input before calling the original forward method.""" input = self.input_quantizer(input) + # Check MR: https://github.com/NVIDIA/Model-Optimizer/pull/824 if hasattr(self, "_forward_pre_dm"): pre_fwd = getattr(self, "_forward_pre_dm") From 2acbb626c7f718e8a0605c41f81631a81f4400b7 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 29 Jan 2026 23:29:19 +0000 Subject: [PATCH 4/5] update the test case Signed-off-by: Jingyu Xin --- tests/unit/torch/opt/test_chaining.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/unit/torch/opt/test_chaining.py b/tests/unit/torch/opt/test_chaining.py index 0025f0346..3bc294f3b 100644 --- a/tests/unit/torch/opt/test_chaining.py +++ b/tests/unit/torch/opt/test_chaining.py @@ -246,14 +246,19 @@ def test_sparse_quantize_kd_linear_forward_backward(): model = SimpleLinearModel() teacher_model = SimpleLinearModel() - def patched_forward(x): - called["patched_forward"] += 1 - w = model.linear.weight - b = model.linear.bias if model.linear.bias is not None else None - return F.linear(x, w, b) + called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0} + + def _make_patched_forward(linear): + def patched_forward(x): + called["patched_forward"] += 1 + w = linear.weight + b = linear.bias if linear.bias is not None else None + return F.linear(x, w, b) - model.linear.forward = patched_forward - teacher_model.linear.forward = patched_forward + return patched_forward + + model.linear.forward = 
_make_patched_forward(model.linear) + teacher_model.linear.forward = _make_patched_forward(teacher_model.linear) def _get_linear_kd_mode(): config = { @@ -267,8 +272,6 @@ def _get_linear_kd_mode(): model = mto.apply_mode(model, mode="quantize") model = mto.apply_mode(model, mode=_get_linear_kd_mode()) - called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0} - def _count_quant_input(_m, _inp, _out): called["input_q"] += 1 From c45df20638f5da0bdeab798e3d0f17ba845719af Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 30 Jan 2026 22:34:52 +0000 Subject: [PATCH 5/5] remove some comments Signed-off-by: Jingyu Xin --- modelopt/torch/opt/dynamic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index 2131639be..5d81af871 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -630,7 +630,6 @@ def bind_forward_method_if_needed(self): # accelerate patched module bind_forward_method(self, self.__class__.forward) else: - # https://github.com/NVIDIA/Model-Optimizer/pull/824 if not hasattr(self, "_forward_pre_dm"): # Keep the patched forward for downstream modules that want to call it. self._forward_pre_dm = self.forward
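
Not part of the patches above: a minimal, self-contained sketch of the `_forward_pre_dm` dispatch rule these patches implement. `DynamicWrapper` is an illustrative stand-in that uses composition instead of ModelOpt's in-place class swap; it is not the real DynamicModule/QuantModule API.

import types

import torch
import torch.nn.functional as F
from torch import nn


class DynamicWrapper(nn.Module):
    """Toy stand-in for a converted (dynamic/quantized) module."""

    def __init__(self, inner: nn.Linear):
        super().__init__()
        self.inner = inner
        # Mimic bind_forward_method_if_needed(): if the instance carries a monkey-patched
        # forward, remember it so the converted module can keep honoring it.
        if "forward" in inner.__dict__:
            self._forward_pre_dm = inner.__dict__["forward"]

    def _is_forward_in_mro(self, fn) -> bool:
        # A bound method whose underlying function is some class's `forward` in our MRO is
        # not an external monkey-patch; calling it again from forward() could recurse.
        func = getattr(fn, "__func__", None)
        if func is None:  # plain function / closure -> treat as an external patch
            return False
        return any(cls.__dict__.get("forward") is func for cls in type(self).mro())

    def forward(self, x):
        pre_fwd = getattr(self, "_forward_pre_dm", None)
        if pre_fwd is not None and not self._is_forward_in_mro(pre_fwd):
            return pre_fwd(x)  # keep honoring the external monkey-patch
        return self.inner(x)  # normal path (stands in for super().forward)


lin = nn.Linear(4, 4)
lin.forward = lambda x: F.linear(x, lin.weight, lin.bias)  # closure-style instance patch
wrapped = DynamicWrapper(lin)
print(wrapped(torch.randn(2, 4)).shape)  # torch.Size([2, 4]) via the patched forward

# Force the pathological state from the regression test: the saved forward already lives in
# the wrapper MRO, so dispatch must fall back instead of recursing.
wrapped._forward_pre_dm = types.MethodType(DynamicWrapper.forward, wrapped)
print(wrapped(torch.randn(2, 4)).shape)  # torch.Size([2, 4]) via the fallback path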