diff --git a/.github/workflows/publish_action.yml b/.github/workflows/publish_action.yml new file mode 100644 index 0000000..25ce502 --- /dev/null +++ b/.github/workflows/publish_action.yml @@ -0,0 +1,24 @@ +name: Publish to Comfy registry +on: + workflow_dispatch: + push: + branches: + - master + paths: + - "pyproject.toml" + +permissions: + issues: write + +jobs: + publish-node: + name: Publish Custom Node to registry + runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'silveroxides' }} + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@v1 + with: + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. diff --git a/README.md b/README.md index c205b41..4988b98 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,27 @@ -#### NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results: https://github.com/city96/ComfyUI-GGUF +#### ~~NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results~~ +Some users can experience speedup by combining loading UNET as NF4 using the loader from this repo and load T5XXL as GGUF using the repo from https://github.com/city96/ComfyUI-GGUF Now on the [manager](https://github.com/ltdrdata/ComfyUI-Manager) for easy installation. Make sure to select Channel:dev in the ComfyUI manager menu or install via git url. -A quickly written custom node that uses code from [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) to support the [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors) and [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors). 
+You can find the checkpoints and UNET in the linked repositories on huggingface or by searching for NF4 on Civitai + +### [CivitAI search link](https://civitai.com/search/models?baseModel=Flux.1%20S&baseModel=Flux.1%20D&sortBy=models_v9&query=nf4) + +### [nf4 flux unet only](https://huggingface.co/silveroxides/flux1-nf4-unet) + +### [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors) + +### [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors) Requires installing bitsandbytes. Make sure your ComfyUI is updated. -The node is: CheckpointLoaderNF4, just plug it in your flux workflow instead of the regular one. +The nodes are: +#### "CheckpointLoaderNF4": "Load NF4 Flux Checkpoint" + +#### "UNETLoaderNF4": "Load NF4 Flux UNET" + +just plug it in your flux workflow instead of the regular ones. + +Code adapted from the implementation by lllyasviel at [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge). 
diff --git a/__init__.py b/__init__.py index 8b72c7f..f0039e9 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,7 @@ #shamelessly taken from forge -import nodes import folder_paths -import bitsandbytes - import torch import bitsandbytes as bnb @@ -47,14 +44,39 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta state2=state2, ) - class ForgeParams4bit(Params4bit): - def to(self, *args, **kwargs): + _torch_fn_depth=0 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach: + return super().__torch_function__(func, types, args, kwargs or {}) + cls._torch_fn_depth += 1 + try: + slf = args[0] + n = cls( + torch.nn.Parameter.detach(slf), + requires_grad=slf.requires_grad, + quant_state=copy_quant_state(slf.quant_state, slf.device), + blocksize=slf.blocksize, + compress_statistics=slf.compress_statistics, + quant_type=slf.quant_type, + quant_storage=slf.quant_storage, + bnb_quantized=slf.bnb_quantized, + module=slf.module + ) + return n + finally: + cls._torch_fn_depth -= 1 + + def to(self, *args, copy=False, **kwargs): + if copy: + return self.clone().to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: - n = ForgeParams4bit( + n = self.__class__( torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=copy_quant_state(self.quant_state, device), @@ -126,59 +148,170 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss else: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -current_device = None -current_dtype = None -current_manual_cast_enabled = False -current_bnb_dtype = None import comfy.ops -class 
OPS(comfy.ops.manual_cast): - class Linear(ForgeLoader4Bit): - def __init__(self, *args, device=None, dtype=None, **kwargs): - super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) - self.parameters_manual_cast = current_manual_cast_enabled - - def forward(self, x): - self.weight.quant_state = self.quant_state - - if self.bias is not None and self.bias.dtype != x.dtype: - # Maybe this can also be set to all non-bnb ops since the cost is very low. - # And it only invokes one time, and most linear does not have bias - self.bias.data = self.bias.data.to(x.dtype) - - if not self.parameters_manual_cast: - return functional_linear_4bits(x, self.weight, self.bias) - elif not self.weight.bnb_quantized: - assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' - layer_original_device = self.weight.device - self.weight = self.weight._quantize(x.device) - bias = self.bias.to(x.device) if self.bias is not None else None - out = functional_linear_4bits(x, self.weight, bias) - self.weight = self.weight.to(layer_original_device) - return out - else: - weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True) - with main_stream_worker(weight, bias, signal): - return functional_linear_4bits(x, weight, bias) - +def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None): + + class OPS(comfy.ops.manual_cast): + class Linear(loader_class): + def __init__(self, *args, device=None, dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) + self.parameters_manual_cast = current_manual_cast_enabled + + def forward(self, x): + self.weight.quant_state = self.quant_state + + if self.bias is not None and self.bias.dtype != x.dtype: + # Maybe this can also be set to all non-bnb ops since the cost is very low. 
+ # And it only invokes one time, and most linear does not have bias + self.bias.data = self.bias.data.to(x.dtype) + + if not self.parameters_manual_cast: + return functional_linear_4bits(x, self.weight, self.bias) + elif not self.weight.bnb_quantized: + assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' + layer_original_device = self.weight.device + self.weight = self.weight._quantize(x.device) + bias = self.bias.to(x.device) if self.bias is not None else None + out = functional_linear_4bits(x, self.weight, bias) + self.weight = self.weight.to(layer_original_device) + return out + else: + raise RuntimeError("Unexpected state in forward") + + return OPS class CheckpointLoaderNF4: + NodeId = 'CheckpointLoaderNF4' + NodeName = 'Load FP4 or NF4 Quantized Checkpoint Model' @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + return {"required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders" - def load_checkpoint(self, ckpt_name): + def load_checkpoint(self, ckpt_name, bnb_dtype="default"): + if bnb_dtype == "default": + bnb_dtype = None + ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS}) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops}) return out[:3] -NODE_CLASS_MAPPINGS = { - "CheckpointLoaderNF4": CheckpointLoaderNF4, -} +class UNETLoaderNF4: + NodeId = 'UNETLoaderNF4' 
+    NodeName = 'Load FP4 or NF4 Quantized Diffusion or UNET Model' +    @classmethod +    def INPUT_TYPES(s): +        return {"required": { +            "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), +            "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), +        }} +    RETURN_TYPES = ("MODEL",) +    FUNCTION = "load_unet" + +    CATEGORY = "advanced/loaders" + +    def load_unet(self, unet_name, bnb_dtype="default"): +        if bnb_dtype == "default": +            bnb_dtype = None +        ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype) +        unet_path = folder_paths.get_full_path("diffusion_models", unet_name) +        model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": ops}) +        return (model,) + +class CLIPLoaderNF4: +    # WIP +    NodeId = 'CLIPLoaderNF4' +    NodeName = 'Load FP4 or NF4 Quantized Text Encoder' +    @classmethod +    def INPUT_TYPES(s): +        return {"required": { +            "clip_name": (folder_paths.get_filename_list("text_encoders"), ), +            "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), +            "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), +        }} +    RETURN_TYPES = ("CLIP",) +    FUNCTION = "load_clip" + +    CATEGORY = "advanced/loaders" + +    def load_clip(self, clip_name, type, bnb_dtype="default"): +        if type == "stable_cascade": +            clip_type = comfy.sd.CLIPType.STABLE_CASCADE +        elif type == "sd3": +            clip_type = comfy.sd.CLIPType.SD3 +        elif type == "stable_audio": +            clip_type = comfy.sd.CLIPType.STABLE_AUDIO +        elif type == "mochi": +            clip_type = comfy.sd.CLIPType.MOCHI +        elif type == "ltxv": +            clip_type = comfy.sd.CLIPType.LTXV +        elif type == "pixart": +            clip_type = comfy.sd.CLIPType.PIXART +        else: +            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + +        clip_path = folder_paths.get_full_path("text_encoders", clip_name) +        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": make_ops(ForgeLoader4Bit, current_bnb_dtype = None if bnb_dtype == "default" else bnb_dtype)}) +        return (clip,) +class DualCLIPLoaderNF4: +    # 
WIP +    NodeId = 'DualCLIPLoaderNF4' +    NodeName = 'Load FP4 or NF4 Quantized Text Encoders' +    @classmethod +    def INPUT_TYPES(s): +        return {"required": { +            "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), +            "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), +            "type": (["sdxl", "sd3", "flux", "hunyuan_video", "custom_flux", "fluxmod"], ), +            "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), +        }} +    RETURN_TYPES = ("CLIP",) +    FUNCTION = "load_clip" + +    CATEGORY = "advanced/loaders" + +    def load_clip(self, clip_name1, clip_name2, type, bnb_dtype="default"): +        clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) +        clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) +        if type == "sdxl": +            clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION +        elif type == "sd3": +            clip_type = comfy.sd.CLIPType.SD3 +        elif type == "flux": +            clip_type = comfy.sd.CLIPType.FLUX +        elif type == "hunyuan_video": +            clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO +        elif type == "custom_flux": +            clip_type = comfy.sd.CLIPType.FLUXC +        elif type == "fluxmod": +            clip_type = comfy.sd.CLIPType.FLUXMOD + +        clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": make_ops(ForgeLoader4Bit, current_bnb_dtype = None if bnb_dtype == "default" else bnb_dtype)}) +        return (clip,) + +node_list = [ +    # Checkpoint model loaders +    CheckpointLoaderNF4, +    # Diffusion model loaders +    UNETLoaderNF4, +    # Text encoder model loaders +    # WIP CLIPLoaderNF4, +    # WIP DualCLIPLoaderNF4, +] + +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: +    NODE_CLASS_MAPPINGS[node.NodeId] = node +    NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..33e8bf1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "ComfyUI_bnb_nf4_fp4_Loaders" +description = 'Checkpoint and UNET/Diffusion Model loaders for 
BitsAndBytes NF4 quantized models' +version = "1.0.5" +license = { file = "LICENSE.txt" } + +[project.urls] +Repository = "https://github.com/silveroxides/ComfyUI_bitsandbytes_NF4" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "silveroxides" +DisplayName = "BitsAndBytes NF4 Loaders" +Icon = ""