Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions .github/workflows/publish_action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Publish to Comfy registry
on:
workflow_dispatch:
push:
branches:
- master
paths:
- "pyproject.toml"

permissions:
issues: write

jobs:
publish-node:
name: Publish Custom Node to registry
runs-on: ubuntu-latest
if: ${{ github.repository_owner == 'silveroxides' }}
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Publish Custom Node
uses: Comfy-Org/publish-node-action@v1
with:
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here.
22 changes: 19 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
#### NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results: https://github.com/city96/ComfyUI-GGUF
#### ~~NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results~~
Some users may see a speedup by combining the two: load the UNET as NF4 with the loader from this repo, and load the T5XXL text encoder as GGUF with https://github.com/city96/ComfyUI-GGUF

Now on the [manager](https://github.com/ltdrdata/ComfyUI-Manager) for easy installation. Make sure to select Channel:dev in the ComfyUI manager menu or install via git url.

A quickly written custom node that uses code from [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) to support the [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors) and [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors).
You can find the checkpoints and UNET in the linked repositories on huggingface or by searching for NF4 on Civitai

### [CivitAI search link](https://civitai.com/search/models?baseModel=Flux.1%20S&baseModel=Flux.1%20D&sortBy=models_v9&query=nf4)

### [nf4 flux unet only](https://huggingface.co/silveroxides/flux1-nf4-unet)

### [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors)

### [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors)

Requires installing bitsandbytes.

Make sure your ComfyUI is updated.

The node is: CheckpointLoaderNF4, just plug it in your flux workflow instead of the regular one.
The nodes are:
#### "CheckpointLoaderNF4": "Load NF4 Flux Checkpoint"

#### "UNETLoaderNF4": "Load NF4 Flux UNET"

Just plug them into your Flux workflow in place of the regular loaders.

Code adapted from the implementation by Illyasviel at [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge).
225 changes: 179 additions & 46 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#shamelessly taken from forge

import nodes
import folder_paths

import bitsandbytes

import torch
import bitsandbytes as bnb

Expand Down Expand Up @@ -47,14 +44,39 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta
state2=state2,
)


class ForgeParams4bit(Params4bit):
def to(self, *args, **kwargs):
_torch_fn_depth=0

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach:
return super().__torch_function__(func, types, args, kwargs or {})
cls._torch_fn_depth += 1
try:
slf = args[0]
n = cls(
torch.nn.Parameter.detach(slf),
requires_grad=slf.requires_grad,
quant_state=copy_quant_state(slf.quant_state, slf.device),
blocksize=slf.blocksize,
compress_statistics=slf.compress_statistics,
quant_type=slf.quant_type,
quant_storage=slf.quant_storage,
bnb_quantized=slf.bnb_quantized,
module=slf.module
)
return n
finally:
cls._torch_fn_depth -= 1

def to(self, *args, copy=False, **kwargs):
if copy:
return self.clone().to(*args, **kwargs)
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == "cuda" and not self.bnb_quantized:
return self._quantize(device)
else:
n = ForgeParams4bit(
n = self.__class__(
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
requires_grad=self.requires_grad,
quant_state=copy_quant_state(self.quant_state, device),
Expand Down Expand Up @@ -126,59 +148,170 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
else:
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

current_device = None
current_dtype = None
current_manual_cast_enabled = False
current_bnb_dtype = None

import comfy.ops

class OPS(comfy.ops.manual_cast):
class Linear(ForgeLoader4Bit):
def __init__(self, *args, device=None, dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
self.parameters_manual_cast = current_manual_cast_enabled

def forward(self, x):
self.weight.quant_state = self.quant_state

if self.bias is not None and self.bias.dtype != x.dtype:
# Maybe this can also be set to all non-bnb ops since the cost is very low.
# And it only invokes one time, and most linear does not have bias
self.bias.data = self.bias.data.to(x.dtype)

if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
layer_original_device = self.weight.device
self.weight = self.weight._quantize(x.device)
bias = self.bias.to(x.device) if self.bias is not None else None
out = functional_linear_4bits(x, self.weight, bias)
self.weight = self.weight.to(layer_original_device)
return out
else:
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
with main_stream_worker(weight, bias, signal):
return functional_linear_4bits(x, weight, bias)

def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None):

class OPS(comfy.ops.manual_cast):
class Linear(loader_class):
def __init__(self, *args, device=None, dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
self.parameters_manual_cast = current_manual_cast_enabled

def forward(self, x):
self.weight.quant_state = self.quant_state

if self.bias is not None and self.bias.dtype != x.dtype:
# Maybe this can also be set to all non-bnb ops since the cost is very low.
# And it only invokes one time, and most linear does not have bias
self.bias.data = self.bias.data.to(x.dtype)

if not self.parameters_manual_cast:
return functional_linear_4bits(x, self.weight, self.bias)
elif not self.weight.bnb_quantized:
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
layer_original_device = self.weight.device
self.weight = self.weight._quantize(x.device)
bias = self.bias.to(x.device) if self.bias is not None else None
out = functional_linear_4bits(x, self.weight, bias)
self.weight = self.weight.to(layer_original_device)
return out
else:
raise RuntimeError("Unexpected state in forward")

return OPS

class CheckpointLoaderNF4:
NodeId = 'CheckpointLoaderNF4'
NodeName = 'Load FP4 or NF4 Quantized Checkpoint Model'
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
return {"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
}}

RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders"

def load_checkpoint(self, ckpt_name):
def load_checkpoint(self, ckpt_name, bnb_dtype="default"):
if bnb_dtype == "default":
bnb_dtype = None
ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS})
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops})
return out[:3]

NODE_CLASS_MAPPINGS = {
"CheckpointLoaderNF4": CheckpointLoaderNF4,
}
class UNETLoaderNF4:
NodeId = 'UNETLoaderNF4'
NodeName = 'Load FP4 or NF4 Quantized Diffusion or UNET Model'
@classmethod
def INPUT_TYPES(s):
return {"required": {
"unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"

CATEGORY = "advanced/loaders"

def load_unet(self, unet_name, bnb_dtype="default"):
if bnb_dtype == "default":
bnb_dtype = None
ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype)
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": ops})
return (model,)

class CLIPLoaderNF4:
# WIP
NodeId = 'CLIPLoaderNF4'
NodeName = 'Load FP4 or NF4 Quantized Text Encoder'
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ),
"bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"

CATEGORY = "advanced/loaders"

def load_clip(self, clip_name, type):
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "stable_audio":
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
elif type == "pixart":
clip_type = comfy.sd.CLIPType.PIXART
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

clip_path = folder_paths.get_full_path("clip", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": OPS})
return (clip,)

class DualCLIPLoaderNF4:
# WIP
NodeId = 'DualCLIPLoaderNF4'
NodeName = 'Load FP4 or NF4 Quantized Text Encoders'
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "custom_flux", "fluxmod"], ),
"bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"

CATEGORY = "advanced/loaders"

def load_clip(self, clip_name1, clip_name2, type):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
if type == "sdxl":
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
elif type == "hunyuan_video":
clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO
elif type == "custom_flux":
clip_type = comfy.sd.CLIPType.FLUXC
elif type == "fluxmod":
clip_type = comfy.sd.CLIPType.FLUXMOD

clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": OPS})
return (clip,)

node_list = [
# Checkpoint model loaders
CheckpointLoaderNF4,
# Diffusion model loaders
UNETLoaderNF4,
# Text encoder model loaders
# WIP CLIPLoaderNF4,
# WIP DualCLIPLoaderNF4,
]

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

for node in node_list:
NODE_CLASS_MAPPINGS[node.NodeId] = node
NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[project]
name = "ComfyUI_bnb_nf4_fp4_Loaders"
description = 'Checkpoint and UNET/Diffusion Model loaders for <a href="https://github.com/bitsandbytes-foundation/bitsandbytes">BitsAndBytes</a> NF4 quantized models'
version = "1.0.5"
license = { file = "LICENSE.txt" }

[project.urls]
Repository = "https://github.com/silveroxides/ComfyUI_bitsandbytes_NF4"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "silveroxides"
DisplayName = "BitsAndBytes NF4 Loaders"
Icon = ""