From f329975da78ce71ee350da22f8641b9f76b7faa5 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 12 Aug 2024 05:31:35 -0600 Subject: [PATCH 1/3] Make loading unquantized models work --- __init__.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/__init__.py b/__init__.py index 8b72c7f..34ed7cc 100644 --- a/__init__.py +++ b/__init__.py @@ -47,9 +47,16 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta state2=state2, ) - class ForgeParams4bit(Params4bit): - def to(self, *args, **kwargs): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func != torch._C.TensorBase.detach: + return super().__torch_function__(func, types, args, kwargs or {}) + return args[0] + + def to(self, *args, copy=False, **kwargs): + if copy: + return self.to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) @@ -78,7 +85,7 @@ def __init__(self, *, device, dtype, quant_type, **kwargs): self.weight = None self.quant_state = None self.bias = None - self.quant_type = quant_type + self.quant_type = quant_type if quant_type is not None else "fp4" def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) From 1f1e58ca452933d15aa96f0771ff8cedd5eaf7f9 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 12 Aug 2024 06:55:13 -0600 Subject: [PATCH 2/3] Fix to with copy enabled, hopefully more reliable detach --- __init__.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/__init__.py b/__init__.py index 34ed7cc..4521043 100644 --- a/__init__.py +++ b/__init__.py @@ -48,15 +48,35 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta ) class ForgeParams4bit(Params4bit): + _torch_fn_depth=0 + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if func != torch._C.TensorBase.detach: + if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach: return super().__torch_function__(func, types, args, kwargs or {}) - return args[0] + cls._torch_fn_depth += 1 + try: + slf = args[0] + n = ForgeParams4bit( + torch.nn.Parameter.detach(slf), + requires_grad=slf.requires_grad, + quant_state=copy_quant_state(slf.quant_state, slf.device), + blocksize=slf.blocksize, + compress_statistics=slf.compress_statistics, + quant_type=slf.quant_type, + quant_storage=slf.quant_storage, + bnb_quantized=slf.bnb_quantized, + module=slf.module + ) + return n + finally: + cls._torch_fn_depth -= 1 + + # return args[0] def to(self, *args, copy=False, **kwargs): if copy: - return self.to(*args, **kwargs) + return self.clone().to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) @@ -165,9 +185,7 @@ def forward(self, x): self.weight = self.weight.to(layer_original_device) return out else: - weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True) - with main_stream_worker(weight, bias, signal): - return functional_linear_4bits(x, weight, bias) + raise RuntimeError("Unexpected state in forward") class CheckpointLoaderNF4: From 0bc0e411c0e6a3a54d91e97c68c8c4776f2ccb1b Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 12 Aug 2024 08:30:42 -0600 Subject: [PATCH 3/3] Support nf4 and fp4 Add CheckpointLoaderBNB node that allows selecting the BNB dtype --- __init__.py | 98 +++++++++++++++++++++++++++++------------------------ 1 file changed, 53 insertions(+), 45 deletions(-) diff --git a/__init__.py b/__init__.py index 4521043..e5e0ec3 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,7 @@ #shamelessly taken from forge -import nodes import folder_paths -import bitsandbytes - import torch import bitsandbytes as bnb @@ -57,7 +54,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): cls._torch_fn_depth += 1 try: slf = args[0] - n = ForgeParams4bit( + n = cls( torch.nn.Parameter.detach(slf), requires_grad=slf.requires_grad, quant_state=copy_quant_state(slf.quant_state, slf.device), @@ -72,8 +69,6 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): finally: cls._torch_fn_depth -= 1 - # return args[0] - def to(self, *args, copy=False, **kwargs): if copy: return self.clone().to(*args, **kwargs) @@ -81,7 +76,7 @@ def to(self, *args, copy=False, **kwargs): if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: - n = ForgeParams4bit( + n = self.__class__( torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=copy_quant_state(self.quant_state, device), @@ -105,7 +100,7 @@ def __init__(self, *, device, dtype, quant_type, **kwargs): self.weight = None self.quant_state = None self.bias = None - self.quant_type = quant_type if quant_type is not None else "fp4" + self.quant_type = quant_type def _save_to_state_dict(self, destination, prefix, keep_vars): super()._save_to_state_dict(destination, prefix, keep_vars) @@ -153,57 +148,70 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss else: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -current_device = None -current_dtype = None -current_manual_cast_enabled = False -current_bnb_dtype = None import comfy.ops -class OPS(comfy.ops.manual_cast): - class Linear(ForgeLoader4Bit): - def __init__(self, *args, device=None, dtype=None, **kwargs): - super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) - self.parameters_manual_cast = current_manual_cast_enabled - - def forward(self, x): - self.weight.quant_state = self.quant_state - - if self.bias is not None and self.bias.dtype != x.dtype: - # Maybe this can also be set to all non-bnb ops since the cost is very low. - # And it only invokes one time, and most linear does not have bias - self.bias.data = self.bias.data.to(x.dtype) - - if not self.parameters_manual_cast: - return functional_linear_4bits(x, self.weight, self.bias) - elif not self.weight.bnb_quantized: - assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' - layer_original_device = self.weight.device - self.weight = self.weight._quantize(x.device) - bias = self.bias.to(x.device) if self.bias is not None else None - out = functional_linear_4bits(x, self.weight, bias) - self.weight = self.weight.to(layer_original_device) - return out - else: - raise RuntimeError("Unexpected state in forward") - - -class CheckpointLoaderNF4: +def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None): + + class OPS(comfy.ops.manual_cast): + class Linear(loader_class): + def __init__(self, *args, device=None, dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) + self.parameters_manual_cast = current_manual_cast_enabled + + def forward(self, x): + self.weight.quant_state = self.quant_state + + if self.bias is not None and self.bias.dtype != x.dtype: + # Maybe this can also be set to all non-bnb ops since the cost is very low. + # And it only invokes one time, and most linear does not have bias + self.bias.data = self.bias.data.to(x.dtype) + + if not self.parameters_manual_cast: + return functional_linear_4bits(x, self.weight, self.bias) + elif not self.weight.bnb_quantized: + assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' + layer_original_device = self.weight.device + self.weight = self.weight._quantize(x.device) + bias = self.bias.to(x.device) if self.bias is not None else None + out = functional_linear_4bits(x, self.weight, bias) + self.weight = self.weight.to(layer_original_device) + return out + else: + raise RuntimeError("Unexpected state in forward") + + return OPS + + +class CheckpointLoaderBNB: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + return {"required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders" - def load_checkpoint(self, ckpt_name): + def load_checkpoint(self, ckpt_name, bnb_dtype="default"): + if bnb_dtype == "default": + bnb_dtype = None + ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS}) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops}) return out[:3] +class CheckpointLoaderNF4(CheckpointLoaderBNB): + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), } } + + NODE_CLASS_MAPPINGS = { "CheckpointLoaderNF4": CheckpointLoaderNF4, + "CheckpointLoaderBNB": CheckpointLoaderBNB, }