Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 78 additions & 45 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#shamelessly taken from forge

import nodes
import folder_paths

import bitsandbytes

import torch
import bitsandbytes as bnb

Expand Down Expand Up @@ -47,14 +44,39 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta
state2=state2,
)


class ForgeParams4bit(Params4bit):
def to(self, *args, **kwargs):
_torch_fn_depth=0

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Intercept ``Tensor.detach`` so the detached tensor stays a ``cls``.

    Only ``torch._C.TensorBase.detach`` is special-cased: the detached
    result is rebuilt as a full ``cls`` instance carrying the bitsandbytes
    4-bit fields (quant_state, blocksize, quant_type, ...). Every other
    torch function falls through to the default handler.
    """
    # Re-entrancy guard: constructing cls(...) below runs tensor ops that
    # presumably re-enter __torch_function__ — nested calls (depth > 0)
    # must take the default path instead of recursing here. TODO confirm.
    if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach:
        return super().__torch_function__(func, types, args, kwargs or {})
    cls._torch_fn_depth += 1
    try:
        # args[0] is the tensor detach() was called on.
        slf = args[0]
        n = cls(
            # Detach through the plain Parameter implementation so this
            # hook is not triggered again for the detach itself.
            torch.nn.Parameter.detach(slf),
            requires_grad=slf.requires_grad,
            # Copy the quant state (see copy_quant_state) so the new
            # tensor does not share mutable state with the original.
            quant_state=copy_quant_state(slf.quant_state, slf.device),
            blocksize=slf.blocksize,
            compress_statistics=slf.compress_statistics,
            quant_type=slf.quant_type,
            quant_storage=slf.quant_storage,
            bnb_quantized=slf.bnb_quantized,
            module=slf.module
        )
        return n
    finally:
        cls._torch_fn_depth -= 1

def to(self, *args, copy=False, **kwargs):
if copy:
return self.clone().to(*args, **kwargs)
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
n = ForgeParams4bit(
n = self.__class__(
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=copy_quant_state(self.quant_state, device),
Expand Down Expand Up @@ -126,59 +148,70 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None

import comfy.ops

class OPS(comfy.ops.manual_cast):
    """comfy.ops operations set whose Linear runs bitsandbytes 4-bit matmul.

    NOTE(review): this appears to be the deletion side of a diff — the
    added `make_ops` factory below builds an equivalent class; kept
    byte-identical here.
    """
    class Linear(ForgeLoader4Bit):
        def __init__(self, *args, device=None, dtype=None, **kwargs):
            # quant type / manual-cast flag come from the module-level
            # globals current_bnb_dtype / current_manual_cast_enabled,
            # set elsewhere before layers are constructed.
            super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
            self.parameters_manual_cast = current_manual_cast_enabled

        def forward(self, x):
            # Re-attach the layer's quant_state to the weight: loading /
            # .to() paths can leave the Params4bit without it.
            self.weight.quant_state = self.quant_state

            if self.bias is not None and self.bias.dtype != x.dtype:
                # Maybe this can also be set to all non-bnb ops since the cost is very low.
                # And it only invokes one time, and most linear does not have bias
                self.bias.data = self.bias.data.to(x.dtype)

            if not self.parameters_manual_cast:
                # Fast path: weight already lives on the compute device.
                return functional_linear_4bits(x, self.weight, self.bias)
            elif not self.weight.bnb_quantized:
                # Weight not yet quantized: quantize on the input's CUDA
                # device, run the matmul, then move the weight back to
                # wherever it was stored (e.g. CPU offload).
                assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                layer_original_device = self.weight.device
                self.weight = self.weight._quantize(x.device)
                bias = self.bias.to(x.device) if self.bias is not None else None
                out = functional_linear_4bits(x, self.weight, bias)
                self.weight = self.weight.to(layer_original_device)
                return out
            else:
                # Manual-cast path: stream weights to the compute device
                # without dtype conversion, overlapping with compute.
                weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
                with main_stream_worker(weight, bias, signal):
                    return functional_linear_4bits(x, weight, bias)


class CheckpointLoaderNF4:
def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None):
    """Build a comfy.ops operations class whose Linear is 4-bit (bitsandbytes).

    Parameters are captured by closure into the generated classes:
      loader_class                 -- base class for Linear (e.g. ForgeLoader4Bit)
      current_manual_cast_enabled  -- enables the quantize-on-demand path in forward
      current_bnb_dtype            -- quant type forwarded to loader_class.__init__
                                      (None presumably means "use default" — TODO confirm)
    NOTE(review): current_device and current_dtype are accepted but never
    used in this body — confirm whether callers rely on them.

    Returns the generated OPS class, suitable for
    model_options={"custom_operations": ...}.
    """

    class OPS(comfy.ops.manual_cast):
        class Linear(loader_class):
            def __init__(self, *args, device=None, dtype=None, **kwargs):
                super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
                self.parameters_manual_cast = current_manual_cast_enabled

            def forward(self, x):
                # Re-attach the layer's quant_state to the weight: loading /
                # .to() paths can leave the Params4bit without it.
                self.weight.quant_state = self.quant_state

                if self.bias is not None and self.bias.dtype != x.dtype:
                    # Maybe this can also be set to all non-bnb ops since the cost is very low.
                    # And it only invokes one time, and most linear does not have bias
                    self.bias.data = self.bias.data.to(x.dtype)

                if not self.parameters_manual_cast:
                    # Fast path: weight is already quantized and resident.
                    return functional_linear_4bits(x, self.weight, self.bias)
                elif not self.weight.bnb_quantized:
                    # Quantize on the input's CUDA device, compute, then
                    # return the weight to its original (offload) device.
                    assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
                    layer_original_device = self.weight.device
                    self.weight = self.weight._quantize(x.device)
                    bias = self.bias.to(x.device) if self.bias is not None else None
                    out = functional_linear_4bits(x, self.weight, bias)
                    self.weight = self.weight.to(layer_original_device)
                    return out
                else:
                    # Already quantized AND manual-cast enabled: no streaming
                    # fallback exists in this variant (the old weights_manual_cast
                    # path was removed), so this state is treated as a bug.
                    raise RuntimeError("Unexpected state in forward")

    return OPS


class CheckpointLoaderBNB:
    """ComfyUI node: load a checkpoint using bitsandbytes 4-bit linear ops.

    Fix: the pasted diff contained both the pre- and post-change versions
    of INPUT_TYPES / load_checkpoint interleaved (two `def load_checkpoint`
    lines back-to-back is a syntax error); this is the post-change version.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
            # "default" keeps whatever quant type the checkpoint loader
            # chooses on its own; "nf4"/"fp4" force that 4-bit format.
            "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
        }}

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders"

    def load_checkpoint(self, ckpt_name, bnb_dtype="default"):
        """Load *ckpt_name* and return (model, clip, vae).

        bnb_dtype: "default" | "nf4" | "fp4"; "default" is mapped to None,
        which presumably lets ForgeLoader4Bit pick its own quant type —
        TODO confirm against ForgeLoader4Bit.__init__.
        """
        if bnb_dtype == "default":
            bnb_dtype = None
        # Build a custom operations class so every Linear in the model is
        # a bitsandbytes 4-bit layer with the requested quant type.
        ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype)
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops})
        # load_checkpoint_guess_config may return extra items; the node
        # contract is exactly (MODEL, CLIP, VAE).
        return out[:3]

class CheckpointLoaderNF4(CheckpointLoaderBNB):
    """Backward-compatible node: same loader as CheckpointLoaderBNB, but the
    UI exposes only the checkpoint picker (no bnb_dtype input), so
    load_checkpoint runs with its default bnb_dtype."""

    @classmethod
    def INPUT_TYPES(s):
        checkpoint_names = folder_paths.get_filename_list("checkpoints")
        return {"required": {"ckpt_name": (checkpoint_names,)}}


# Registry read by ComfyUI at import time: node type name -> node class.
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderNF4": CheckpointLoaderNF4,
    "CheckpointLoaderBNB": CheckpointLoaderBNB,
}