diff --git a/__init__.py b/__init__.py
index 8b72c7f..e5e0ec3 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,10 +1,7 @@
 #shamelessly taken from forge
-import nodes
 import folder_paths
-import bitsandbytes
-
 import torch
 import bitsandbytes as bnb
@@ -47,14 +44,39 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta
         state2=state2,
     )
 
-
 class ForgeParams4bit(Params4bit):
-    def to(self, *args, **kwargs):
+    _torch_fn_depth=0
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach:
+            return super().__torch_function__(func, types, args, kwargs or {})
+        cls._torch_fn_depth += 1
+        try:
+            slf = args[0]
+            n = cls(
+                torch.nn.Parameter.detach(slf),
+                requires_grad=slf.requires_grad,
+                quant_state=copy_quant_state(slf.quant_state, slf.device),
+                blocksize=slf.blocksize,
+                compress_statistics=slf.compress_statistics,
+                quant_type=slf.quant_type,
+                quant_storage=slf.quant_storage,
+                bnb_quantized=slf.bnb_quantized,
+                module=slf.module
+            )
+            return n
+        finally:
+            cls._torch_fn_depth -= 1
+
+    def to(self, *args, copy=False, **kwargs):
+        if copy:
+            return self.clone().to(*args, **kwargs)
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
         if device is not None and device.type == "cuda" and not self.bnb_quantized:
             return self._quantize(device)
         else:
-            n = ForgeParams4bit(
+            n = self.__class__(
                 torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
                 requires_grad=self.requires_grad,
                 quant_state=copy_quant_state(self.quant_state, device),
@@ -126,59 +148,70 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
         else:
             super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
-current_device = None
-current_dtype = None
-current_manual_cast_enabled = False
-current_bnb_dtype = None
 
 import comfy.ops
 
-class OPS(comfy.ops.manual_cast):
-    class Linear(ForgeLoader4Bit):
-        def __init__(self, *args, device=None, dtype=None, **kwargs):
-            super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
-            self.parameters_manual_cast = current_manual_cast_enabled
-
-        def forward(self, x):
-            self.weight.quant_state = self.quant_state
-
-            if self.bias is not None and self.bias.dtype != x.dtype:
-                # Maybe this can also be set to all non-bnb ops since the cost is very low.
-                # And it only invokes one time, and most linear does not have bias
-                self.bias.data = self.bias.data.to(x.dtype)
-
-            if not self.parameters_manual_cast:
-                return functional_linear_4bits(x, self.weight, self.bias)
-            elif not self.weight.bnb_quantized:
-                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
-                layer_original_device = self.weight.device
-                self.weight = self.weight._quantize(x.device)
-                bias = self.bias.to(x.device) if self.bias is not None else None
-                out = functional_linear_4bits(x, self.weight, bias)
-                self.weight = self.weight.to(layer_original_device)
-                return out
-            else:
-                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
-                with main_stream_worker(weight, bias, signal):
-                    return functional_linear_4bits(x, weight, bias)
-
-
-class CheckpointLoaderNF4:
+def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None):
+
+    class OPS(comfy.ops.manual_cast):
+        class Linear(loader_class):
+            def __init__(self, *args, device=None, dtype=None, **kwargs):
+                super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
+                self.parameters_manual_cast = current_manual_cast_enabled
+
+            def forward(self, x):
+                self.weight.quant_state = self.quant_state
+
+                if self.bias is not None and self.bias.dtype != x.dtype:
+                    # Maybe this can also be set to all non-bnb ops since the cost is very low.
+                    # And it only invokes one time, and most linear does not have bias
+                    self.bias.data = self.bias.data.to(x.dtype)
+
+                if not self.parameters_manual_cast:
+                    return functional_linear_4bits(x, self.weight, self.bias)
+                elif not self.weight.bnb_quantized:
+                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
+                    layer_original_device = self.weight.device
+                    self.weight = self.weight._quantize(x.device)
+                    bias = self.bias.to(x.device) if self.bias is not None else None
+                    out = functional_linear_4bits(x, self.weight, bias)
+                    self.weight = self.weight.to(layer_original_device)
+                    return out
+                else:
+                    raise RuntimeError("Unexpected state in forward")
+
+    return OPS
+
+
+class CheckpointLoaderBNB:
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
-                             }}
+        return {"required": {
+            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+            "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
+        }}
+
     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"
 
     CATEGORY = "loaders"
 
-    def load_checkpoint(self, ckpt_name):
+    def load_checkpoint(self, ckpt_name, bnb_dtype="default"):
+        if bnb_dtype == "default":
+            bnb_dtype = None
+        ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype)
         ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
-        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
+        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops})
         return out[:3]
 
 
+class CheckpointLoaderNF4(CheckpointLoaderBNB):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), } }
+
+
 NODE_CLASS_MAPPINGS = {
     "CheckpointLoaderNF4": CheckpointLoaderNF4,
+    "CheckpointLoaderBNB": CheckpointLoaderBNB,
 }
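
Note on the `__torch_function__` override added above: detaching a `ForgeParams4bit` through the default tensor machinery loses the bnb bookkeeping (`quant_state`, `blocksize`, ...), so the patch intercepts `detach` and rebuilds the parameter via `copy_quant_state`. Below is a minimal, self-contained sketch of that same pattern, not part of this patch; it assumes a recent PyTorch (the same `torch._C.TensorBase.detach` hook the patch compares against), and `MetaParam`/`tag` are hypothetical stand-ins for `ForgeParams4bit` and its quant state:

import torch

class MetaParam(torch.nn.Parameter):
    _torch_fn_depth = 0  # re-entrancy guard, as in the patch

    def __new__(cls, data, tag=None):
        self = super().__new__(cls, data, requires_grad=False)
        self.tag = tag  # stand-in for quant_state & friends
        return self

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Pass everything through except the outermost detach() call.
        if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach:
            return super().__torch_function__(func, types, args, kwargs or {})
        cls._torch_fn_depth += 1
        try:
            slf = args[0]
            # Rebuild the subclass around the detached data so the
            # extra state survives the detach.
            return cls(torch.nn.Parameter.detach(slf), tag=slf.tag)
        finally:
            cls._torch_fn_depth -= 1

p = MetaParam(torch.zeros(4), tag="quant-state-stand-in")
d = p.detach()
print(type(d).__name__, d.tag)  # -> MetaParam quant-state-stand-in

The `_torch_fn_depth` counter keeps the inner `torch.nn.Parameter.detach(slf)` call from re-entering the intercepted branch and recursing.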