
Commit df58c80 ("up")

1 parent: 2a827ec

File tree

2 files changed (+41, -106 lines)


src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py

Lines changed: 41 additions & 25 deletions
@@ -20,11 +20,6 @@
 if is_torch_available():
     import torch
 
-if is_accelerate_available():
-    pass
-
-if is_nunchaku_available():
-    from .utils import replace_with_nunchaku_linear
 
 logger = logging.get_logger(__name__)
 

@@ -79,13 +74,14 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)
 
     def create_quantized_param(
         self,
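
The rewritten check_if_quantized_param now does two things: it skips any parameter that lives under an excluded module, and otherwise treats only torch.nn.Linear weights as quantizable. Below is a standalone sketch of the exclusion test with made-up module names; a key excludes both an exact match and anything nested under it through the "key." prefix.

modules_to_not_convert = ["proj_out", "norm_out"]  # hypothetical exclusion list

def is_excluded(param_name: str) -> bool:
    # Same exact-match / prefix rule as the check above.
    return any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert)

assert is_excluded("proj_out.weight")                       # nested under an excluded module
assert is_excluded("norm_out")                              # exact match
assert not is_excluded("transformer_blocks.0.to_q.weight")  # eligible if the module is an nn.Linear
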
@@ -112,13 +108,32 @@ def create_quantized_param(
             module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
 
         elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
-
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
+            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
+            # how to do that yet.
+            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
+            # way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
+
+            # Rebind the loaded value on the new module.
+            # Note: this will result in
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
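
The main change in create_quantized_param is how the new SVDQW4A4Linear is attached: the old setattr(model, param_name, new_module) call does not swap a nested submodule when param_name is dotted, which is presumably why the commit now resolves the parent module first and rebinds the leaf attribute. A minimal sketch of that pattern with plain torch.nn modules (module names are illustrative, and a second nn.Linear stands in for SVDQW4A4Linear.from_linear):

import torch

model = torch.nn.Module()
model.add_module("block", torch.nn.Module())
model.block.add_module("to_q", torch.nn.Linear(4, 4))

param_name = "block.to_q.weight"
module_name = ".".join(param_name.split(".")[:-1])  # "block.to_q"
parent_name, leaf = module_name.rsplit(".", 1)      # "block", "to_q"
parent = model.get_submodule(parent_name)

new_module = torch.nn.Linear(4, 4, bias=False)      # stand-in for SVDQW4A4Linear.from_linear(...)
setattr(parent, leaf, new_module)                   # rebinds model.block.to_q in place
assert model.get_submodule(module_name) is new_module
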
@@ -157,24 +172,25 @@ def _process_model_before_weight_loading(
         keep_in_fp32_modules: List[str] = [],
         **kwargs,
     ):
-        # TODO: deal with `device_map`
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
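
The commented-out block above sketches how modules that an accelerate-style device_map sends to "cpu" or "disk" would be folded into modules_to_not_convert so they are left unquantized. A standalone illustration of that logic, with a hypothetical device_map:

device_map = {"transformer_blocks.0": 0, "transformer_blocks.1": "cpu", "proj_out": "disk"}
modules_to_not_convert = ["norm_out"]

if isinstance(device_map, dict) and len(device_map.keys()) > 1:
    # Keep offloaded modules unconverted, mirroring the TODO in the hunk above.
    keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
    modules_to_not_convert.extend(keys_on_cpu)

assert modules_to_not_convert == ["norm_out", "transformer_blocks.1", "proj_out"]
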

src/diffusers/quantizers/nunchaku/utils.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments
