14 changes: 14 additions & 0 deletions modelopt/torch/opt/dynamic.py
@@ -584,6 +584,15 @@ def export(self) -> nn.Module:
assert not is_dynamic, "Exported module must not be a DynamicModule anymore!"
delattr(self, "_dm_attribute_manager")

# If this module had a monkey-patched forward before DynamicModule.convert(), we may have
# overridden it by binding the dynamic forward onto the instance (to follow the MRO).
# On final export, restore the original forward to avoid leaking a dynamic forward
# (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance.
# please see: https://github.com/NVIDIA/Model-Optimizer/pull/824
if hasattr(self, "_forward_pre_dm"):
setattr(self, "forward", getattr(self, "_forward_pre_dm"))
delattr(self, "_forward_pre_dm")

return self

@classmethod
@@ -621,6 +630,11 @@ def bind_forward_method_if_needed(self):
# accelerate patched module
bind_forward_method(self, self.__class__.forward)
else:
# https://github.com/NVIDIA/Model-Optimizer/pull/824
Contributor (inline review comment):
I think we may not need this comment, as the IDE can show which PR the code change came from.

if not hasattr(self, "_forward_pre_dm"):
# Keep the patched forward for downstream modules that want to call it.
self._forward_pre_dm = self.forward
bind_forward_method(self, self.__class__.forward)
warnings.warn(
"Received a module with monkey patched forward method. Dynamic converted module"
" might not work."
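For context on the two hunks above, a minimal standalone sketch (not part of the diff; the toy `lin` module and the attribute handling are illustrative) of the pattern being handled: an instance-level forward assignment shadows the class-level forward, so the patched callable is stashed on `_forward_pre_dm` when the class forward is bound onto the instance, and restored again when the module is exported.

import torch
from torch import nn

lin = nn.Linear(4, 4)

# An external framework monkey-patches forward at the instance level; the patch lives in
# the instance __dict__ and shadows nn.Linear.forward.
def patched_forward(x):
    return nn.functional.linear(x, lin.weight, lin.bias)

lin.forward = patched_forward
assert "forward" in lin.__dict__

# Conversion-time handling, analogous to bind_forward_method_if_needed above:
# stash the patched callable, then bind the class-level forward onto the instance.
lin._forward_pre_dm = lin.forward
lin.forward = type(lin).forward.__get__(lin)
_ = lin(torch.randn(2, 4))  # runs the bound class forward, not the patch

# Export-time handling, analogous to the export() hunk above:
# restore the original patched forward and drop the stash.
lin.forward = lin._forward_pre_dm
delattr(lin, "_forward_pre_dm")
assert lin(torch.randn(2, 4)).shape == (2, 4)  # runs the patched forward again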
21 changes: 20 additions & 1 deletion modelopt/torch/quantization/nn/modules/quant_module.py
@@ -110,7 +110,26 @@ class QuantInputBase(QuantModule):
def forward(self, input, *args, **kwargs):
"""Quantize the input before calling the original forward method."""
input = self.input_quantizer(input)
output = super().forward(input, *args, **kwargs)
# See PR: https://github.com/NVIDIA/Model-Optimizer/pull/824
if hasattr(self, "_forward_pre_dm"):
pre_fwd = getattr(self, "_forward_pre_dm")

def _is_forward_in_mro(bound_or_func) -> bool:
# If this is a bound method, compare its underlying function to any `forward`
# implementation in the current MRO. If it matches, it's not an external monkey-patch.
if hasattr(bound_or_func, "__func__"):
fn = bound_or_func.__func__
for cls in type(self).mro():
if cls.__dict__.get("forward") is fn:
return True
return False

if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd):
output = super().forward(input, *args, **kwargs)
else:
output = pre_fwd(input, *args, **kwargs)
else:
output = super().forward(input, *args, **kwargs)
if isinstance(output, tuple):
return (self.output_quantizer(output[0]), *output[1:])
return self.output_quantizer(output)
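To make the `_is_forward_in_mro` guard above easier to follow, a small standalone sketch (illustrative class names, not from the repo): a bound method whose underlying function already appears as `forward` somewhere in the instance's MRO is not an external monkey-patch, and calling it directly from inside the wrapper would re-enter the wrapper and recurse.

import types

class Base:
    def forward(self, x):
        return x + 1

class Wrapper(Base):
    def forward(self, x):
        # A wrapper that delegates to the next forward in the MRO.
        return super().forward(x) * 2

def is_forward_in_mro(obj, bound_or_func) -> bool:
    # Bound methods expose their underlying function via __func__; plain closures do not.
    fn = getattr(bound_or_func, "__func__", None)
    if fn is None:
        return False
    return any(cls.__dict__.get("forward") is fn for cls in type(obj).mro())

w = Wrapper()
external_patch = types.MethodType(lambda self, x: x - 1, w)  # a true monkey-patch
mro_forward = types.MethodType(Wrapper.forward, w)           # already present in the MRO

assert not is_forward_in_mro(w, external_patch)  # safe to call directly
assert is_forward_in_mro(w, mro_forward)         # calling this directly would recurse; fall back to super().forward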
109 changes: 109 additions & 0 deletions tests/unit/torch/opt/test_chaining.py
@@ -15,17 +15,28 @@

import pytest
import torch
import torch.nn.functional as F
from _test_utils.torch.misc import compare_outputs
from _test_utils.torch.opt.utils import apply_mode_with_sampling
from torchvision.models.mobilenetv2 import InvertedResidual

import modelopt.torch.distill as mtd
import modelopt.torch.nas as mtn
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import modelopt.torch.sparsity as mts
from modelopt.torch.utils.distributed import _serialize


class SimpleLinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)

def forward(self, x):
return self.linear(x)


def get_model():
return InvertedResidual(16, 32, 1, 6)

@@ -228,3 +239,101 @@ def test_sparse_quantized_module():
model = mtn.export(model)
assert torch.equal(conv.weight, weight_expected)
assert torch.equal(conv._parameters["weight"], weight_expected), "Weight should be overwritten!"


def test_sparse_quantize_kd_linear_forward_backward():
"""Ensure sparse + quantize + distill works for linear forward/backward."""
model = SimpleLinearModel()
teacher_model = SimpleLinearModel()

called = {"patched_forward": 0, "input_q": 0, "weight_q": 0, "pass": 0}

def _make_patched_forward(linear):
def patched_forward(x):
called["patched_forward"] += 1
w = linear.weight
b = linear.bias if linear.bias is not None else None
return F.linear(x, w, b)

return patched_forward

model.linear.forward = _make_patched_forward(model.linear)
teacher_model.linear.forward = _make_patched_forward(teacher_model.linear)

def _get_linear_kd_mode():
config = {
"teacher_model": teacher_model,
"criterion": {("linear", "linear"): mtd.LogitsDistillationLoss()},
"loss_balancer": mtd.StaticLossBalancer(),
}
return [("kd_loss", config)]

model = mto.apply_mode(model, mode="sparse_magnitude", init_state=True)
model = mto.apply_mode(model, mode="quantize")
model = mto.apply_mode(model, mode=_get_linear_kd_mode())

def _count_quant_input(_m, _inp, _out):
called["input_q"] += 1

def _count_quant_weight(_m, _inp, _out):
called["weight_q"] += 1

model.linear.input_quantizer.register_forward_hook(_count_quant_input)
model.linear.weight_quantizer.register_forward_hook(_count_quant_weight)

model.train()
x = torch.randn(2, 4)
target = torch.randn(2, 4)
output = model(x)
loss = F.mse_loss(output, target)
loss.backward()

assert output.shape == target.shape
assert any(p.grad is not None for p in model.parameters() if p.requires_grad), (
"Expected gradients on student parameters."
)
assert called["patched_forward"] == 2
assert called["input_q"] == 1
assert called["weight_q"] == 1


def test_chained_modes_preserve_forward_patching_during_quantize():
"""Ensure chained modes do not break runtime forward patching during quantize."""
model = InvertedResidual(16, 32, 1, 6).to(torch.float16)
model = mto.apply_mode(model, mode="fastnas", init_state=True)
model = mto.apply_mode(model, mode="export_nas")

conv = model.conv[0][0]
called = {"patched_forward": 0, "input_q": 0, "weight_q": 0}

def patched_forward(x):
called["patched_forward"] += 1
return F.conv2d(
x,
conv.weight,
conv.bias,
conv.stride,
conv.padding,
conv.dilation,
conv.groups,
)

conv.forward = patched_forward

def _count_input(_m, _inp, _out):
called["input_q"] += 1

def _count_weight(_m, _inp, _out):
called["weight_q"] += 1

def forward_loop(model):
conv.input_quantizer.register_forward_hook(_count_input)
conv.weight_quantizer.register_forward_hook(_count_weight)
x = torch.randn(1, 16, 8, 8, dtype=torch.float16)
model(x)

mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

assert called["patched_forward"] == 1
assert called["input_q"] == 1
assert called["weight_q"] == 1
93 changes: 93 additions & 0 deletions tests/unit/torch/quantization/test_forward_patching.py
@@ -0,0 +1,93 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

import torch
import torch.nn.functional as F
from torch import nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization import QuantModuleRegistry
from modelopt.torch.quantization.nn.modules.quant_module import QuantLinearConvBase


def test_quant_input_base_ignores_forward_pre_dm_in_mro():
"""Regression test for recursion when `_forward_pre_dm` points to a wrapper forward in the MRO.

In complex wrapper stacks, `_forward_pre_dm` may accidentally end up referencing a `forward`
method already present in the quant wrapper MRO (e.g. QuantLinearConvBase.forward). If
QuantInputBase.forward calls that directly, it can recurse indefinitely:

QuantLinearConvBase.forward -> super().forward (QuantInputBase.forward)
-> _forward_pre_dm (QuantLinearConvBase.forward) -> ...

The fix is to detect this case and fall back to `super().forward` instead.
"""
lin = nn.Linear(8, 8, bias=False)
QuantModuleRegistry.convert(lin)

# Force the problematic state: `_forward_pre_dm` points to a wrapper forward already in MRO.
lin._forward_pre_dm = types.MethodType(QuantLinearConvBase.forward, lin)

x = torch.randn(2, 8)
y = lin(x)
assert isinstance(y, torch.Tensor)
assert y.shape == (2, 8)


def test_quantize_calibration_calls_quantizers_with_runtime_forward_patch():
"""Regression test for on-the-fly forward patching during mtq.quantize calibration.

Some frameworks replace `module.forward` on-the-fly with a closure just before a forward pass.
During mtq.quantize calibration, quantizers must still run (input + weight at minimum).
"""
lin = nn.Linear(8, 8, bias=True).to(torch.float32)

called = {"patched_forward": 0, "input_q": 0, "weight_q": 0}

# Monkey patch instance-level forward (closure-style, no `self` argument).
def patched_forward(x):
called["patched_forward"] += 1
# Use module parameters directly; if quantization wrappers are active, weight access
# should still be routed through the quantized path.
w = lin.weight.to(dtype=x.dtype)
b = lin.bias.to(dtype=x.dtype) if lin.bias is not None else None
return F.linear(x, w, b)

def _count_input(_m, _inp, _out):
called["input_q"] += 1

def _count_weight(_m, _inp, _out):
called["weight_q"] += 1

lin.forward = patched_forward
x = torch.randn(2, 8, dtype=torch.float16)

def forward_loop(model):
# `forward` was already patched above, before conversion inside mtq.quantize.
# Count quantizer executions during calibration.
model.input_quantizer.register_forward_hook(_count_input)
model.weight_quantizer.register_forward_hook(_count_weight)

model(x)

mtq.quantize(lin, mtq.INT8_DEFAULT_CFG, forward_loop)
lin(x)

assert called["patched_forward"] == 2
assert called["input_q"] == 2
assert called["weight_q"] == 2
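As a follow-up to the calibration assertions above, one way to sanity-check locally that the quantizers actually observed data during mtq.quantize is to inspect their collected ranges after calibration. A sketch under the assumption that calibrated TensorQuantizers expose their range via an amax attribute (the input_quantizer/weight_quantizer names come from the tests above):

# Run after mtq.quantize(lin, mtq.INT8_DEFAULT_CFG, forward_loop) as in the test above.
for name, submodule in lin.named_modules():
    if name.endswith("input_quantizer") or name.endswith("weight_quantizer"):
        # A calibrated, enabled quantizer is expected to report a non-None amax.
        print(name, getattr(submodule, "amax", None))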