
Conversation

@jingyu-ml (Contributor) commented Jan 28, 2026

What does this PR do?

Type of change: new feature

Overview:

This PR improves robustness when forward() is monkey-patched (replaced at runtime) on modules that later get wrapped/converted by ModelOpt (DynamicModule + quant wrappers).

It addresses two concrete failure modes introduced/exposed by supporting “patched forward” modules:

  1. Forward “leakage” after export: a dynamic wrapper forward could remain bound on an instance even after export() restores the original (non-dynamic) class, causing runtime errors in unrelated codepaths (e.g. KD export/save/restore chains). A minimal sketch of this failure appears after this list.

  2. Infinite recursion in quant wrappers: _forward_pre_dm can sometimes point to a wrapper forward that already participates in the class chain, causing a recursion loop when quant wrappers call _forward_pre_dm directly.
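
Failure mode 1 can be reproduced with plain PyTorch, independent of ModelOpt; the wrapper class below is a hypothetical stand-in for a dynamic wrapper, not ModelOpt code:

import torch

class DynamicWrapper(torch.nn.Linear):
    # hypothetical stand-in for a dynamic wrapper class
    def forward(self, x):
        return super().forward(x)

lin = torch.nn.Linear(4, 4)
lin.__class__ = DynamicWrapper                     # "convert": swap in the wrapper class
lin.forward = DynamicWrapper.forward.__get__(lin)  # wrapper forward bound on the *instance*

lin.__class__ = torch.nn.Linear                    # "export": class restored, instance attribute forgotten

try:
    lin(torch.randn(2, 4))
except TypeError as e:
    # the leaked instance-bound wrapper forward no longer matches the restored class,
    # failing with roughly "super(type, obj): obj must be an instance or subtype of type"
    print("leaked wrapper forward:", e)

The fix in this PR also cleans up the instance-level forward on export, so only the original (possibly user-patched) forward remains.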

Usage

import torch

lin = torch.nn.Linear(4, 4)

def upcast_forward(x):
    # external closure: NOT part of any class MRO
    return torch.nn.functional.linear(x, lin.weight.to(x.dtype), lin.bias.to(x.dtype))

lin.forward = upcast_forward  # framework/user patches forward

# Later, ModelOpt converts/wraps the module.
# It stashes the patched function as `_forward_pre_dm` and binds the wrapper forward on the class.

# During quantization, QuantInputBase.forward sees `_forward_pre_dm` is NOT in MRO -> calls it.
# Imagine a module already wrapped by quant classes:
# QuantLinearConvBase.forward -> super().forward -> QuantInputBase.forward -> ...

# If `_forward_pre_dm` accidentally points to QuantLinearConvBase.forward (which IS in MRO),
# and QuantInputBase.forward calls it directly, you get:
# QuantInputBase.forward -> _forward_pre_dm (QuantLinearConvBase.forward)
# -> super().forward -> QuantInputBase.forward -> ...
# infinite recursion

# The fix: if `_forward_pre_dm` is a forward already in MRO, ignore it and use super().forward.
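
The dispatch rule above can be sketched as follows; the class name and the exact attribute handling are illustrative, not the actual QuantInputBase implementation:

import torch

class QuantInputLike(torch.nn.Linear):  # hypothetical stand-in for a quant wrapper
    def forward(self, x):
        pre_dm = getattr(self, "_forward_pre_dm", None)
        # unwrap a bound method so it can be compared by identity against functions stored on classes
        fn = getattr(pre_dm, "__func__", pre_dm)
        already_in_mro = any(vars(cls).get("forward") is fn for cls in type(self).__mro__)
        if pre_dm is not None and not already_in_mro:
            # genuinely external patched forward (e.g. upcast_forward above): call it
            return pre_dm(x)
        # missing, or itself a class-chain forward: defer to super() instead of
        # calling it directly, which is what caused the recursion
        return super().forward(x)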

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • Improved forward method restoration during module export to prevent state leakage
    • Enhanced quantization behavior when using chained optimization modes
  • Tests

    • Added regression tests for quantization with runtime forward patching
    • Added validation tests for sparse quantization combined with distillation workflows


Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners January 28, 2026 21:40
@jingyu-ml jingyu-ml marked this pull request as draft January 28, 2026 21:40
copy-pr-bot (bot) commented Jan 28, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml self-assigned this Jan 28, 2026
coderabbitai (bot) commented Jan 28, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This PR implements a forward method preservation mechanism during dynamic module conversion and quantization. It adds logic to save the pre-monkey-patched forward method to a _forward_pre_dm attribute and restore it on export, while conditional dispatch in quantization determines whether to use the patched or original forward path based on MRO comparison.
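
A minimal sketch of the save-on-convert / restore-on-export bookkeeping described here; the helper names and mechanics are hypothetical (the real logic lives in DynamicModule and the quant wrappers):

import torch

def convert(module, wrapper_cls):
    # stash an instance-level patched forward before the wrapper takes over
    if "forward" in vars(module):
        module._forward_pre_dm = module.forward
        del module.forward            # the class-level wrapper forward is used from here on
    module.__class__ = wrapper_cls

def export(module, original_cls):
    # restore the original class and whatever forward the user had patched in
    module.__class__ = original_cls
    pre_dm = module.__dict__.pop("_forward_pre_dm", None)
    module.__dict__.pop("forward", None)  # drop any leaked instance-bound wrapper forward
    if pre_dm is not None:
        module.forward = pre_dm

lin = torch.nn.Linear(4, 4)
lin.forward = lambda x: torch.nn.functional.linear(x, lin.weight, lin.bias)  # user patch

class Wrapper(torch.nn.Linear):
    def forward(self, x):
        return super().forward(x)

convert(lin, Wrapper)
export(lin, torch.nn.Linear)
assert "_forward_pre_dm" not in lin.__dict__ and lin.forward(torch.randn(2, 4)).shape == (2, 4)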

Changes

Forward Patching Core: modelopt/torch/opt/dynamic.py, modelopt/torch/quantization/nn/modules/quant_module.py
  Added forward method tracking and restoration via the _forward_pre_dm attribute. dynamic.py saves the current forward before binding the wrapper forward and restores it on export. QuantInputBase.forward conditionally dispatches based on whether _forward_pre_dm exists, using an MRO comparison to determine whether it refers to the module's own forward implementation.

Forward Patching Tests: tests/unit/torch/opt/test_chaining.py, tests/unit/torch/quantization/test_forward_patching.py
  Introduced a SimpleLinearModel utility class and four test cases validating forward-patching preservation across sparse + quantize + distill chaining, quantization calibration with runtime patching, and MRO-based forward dispatch correctness.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 28.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: The title 'Update on the QuantModule & DynamicModule to accept external forward' accurately describes the main changes: modifications to both QuantModule and DynamicModule to handle externally monkey-patched forward methods.



codecov (bot) commented Jan 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.85%. Comparing base (81b67dd) to head (2acbb62).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #824      +/-   ##
==========================================
+ Coverage   73.82%   73.85%   +0.02%     
==========================================
  Files         193      193              
  Lines       19745    19763      +18     
==========================================
+ Hits        14577    14595      +18     
  Misses       5168     5168              

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml marked this pull request as ready for review January 29, 2026 23:13
coderabbitai (bot) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/unit/torch/opt/test_chaining.py` (around lines 244-295): the patched_forward closure closes over the outer variable model and is assigned to both model.linear.forward and teacher_model.linear.forward, so the teacher ends up using the student's weights after apply_mode. Fix this by creating a bound forward for each linear instance (e.g. a factory make_patched_forward(linear_module) that captures linear_module and returns a function using linear_module.weight/bias, or bind the function to the instance with types.MethodType), then assign model.linear.forward = make_patched_forward(model.linear) and teacher_model.linear.forward = make_patched_forward(teacher_model.linear) so each module uses its own weights during test_sparse_quantize_kd_linear_forward_backward.
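
A sketch of the suggested factory fix; SimpleLinearModel's definition is assumed here (the actual test utility may differ), and the patched body is simplified to a plain functional linear call:

import torch

class SimpleLinearModel(torch.nn.Module):  # assumed shape of the test utility class
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

def make_patched_forward(linear_module):
    # capture this specific Linear so each model's patched forward uses its own weights
    def patched_forward(x):
        return torch.nn.functional.linear(x, linear_module.weight, linear_module.bias)
    return patched_forward

model, teacher_model = SimpleLinearModel(), SimpleLinearModel()
model.linear.forward = make_patched_forward(model.linear)
teacher_model.linear.forward = make_patched_forward(teacher_model.linear)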

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml enabled auto-merge (squash) January 29, 2026 23:37
@jingyu-ml jingyu-ml requested a review from mxinO January 29, 2026 23:49
@Edwardf0t1 (Contributor) left a comment
LGTM.

# accelerate patched module
bind_forward_method(self, self.__class__.forward)
else:
# https://github.com/NVIDIA/Model-Optimizer/pull/824
A contributor commented on this snippet:
I think we may not need this comment, since the IDE can show which PR the code change came from.
