From cd2674c1cde2e3d9b61b4fa2648f605ebbce8f7b Mon Sep 17 00:00:00 2001 From: Silver Date: Fri, 16 Aug 2024 20:05:19 +0200 Subject: [PATCH 01/16] Add UNET Loader --- __init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/__init__.py b/__init__.py index 8b72c7f..540e224 100644 --- a/__init__.py +++ b/__init__.py @@ -178,7 +178,27 @@ def load_checkpoint(self, ckpt_name): out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS}) return out[:3] +class UNETLoaderNF4: + @classmethod + def INPUT_TYPES(s): + return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_unet" + + CATEGORY = "advanced/loaders" + + def load_unet(self, unet_name): + unet_path = folder_paths.get_full_path("unet", unet_name) + model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": OPS}) + return (model,) + NODE_CLASS_MAPPINGS = { "CheckpointLoaderNF4": CheckpointLoaderNF4, + "UNETLoaderNF4": UNETLoaderNF4, } +NODE_DISPLAY_NAME_MAPPINGS = { + "CheckpointLoaderNF4": "Load NF4 Flux Checkpoint", + "UNETLoaderNF4": "Load NF4 Flux UNET", +} From ed51115578207927b1aa56cab06784f4d441a978 Mon Sep 17 00:00:00 2001 From: Udhul <126940798+Udhul@users.noreply.github.com> Date: Wed, 21 Aug 2024 21:58:16 +0200 Subject: [PATCH 02/16] Support V2 Model storing chunk 64 norm in full precision float32 and removing the second stage of double quantization. 
--- __init__.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 8b72c7f..5a99bc4 100644 --- a/__init__.py +++ b/__init__.py @@ -36,7 +36,7 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta else None ) - return QuantState( + quant_state = QuantState( absmax=state.absmax.to(device), shape=state.shape, code=state.code.to(device), @@ -47,6 +47,11 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta state2=state2, ) + # Manually add chunk_64_norm as an attribute if it exists + if hasattr(state, 'chunk_64_norm'): + quant_state.chunk_64_norm = state.chunk_64_norm.to(device) + + return quant_state class ForgeParams4bit(Params4bit): def to(self, *args, **kwargs): @@ -65,6 +70,11 @@ def to(self, *args, **kwargs): bnb_quantized=self.bnb_quantized, module=self.module ) + + # Manually copy chunk_64_norm if it exists + if hasattr(self.quant_state, 'chunk_64_norm'): + n.quant_state.chunk_64_norm = self.quant_state.chunk_64_norm + self.module.quant_state = n.quant_state self.data = n.data self.quant_state = n.quant_state From 06c84e09dc2087317632187577c48aaf25f0c791 Mon Sep 17 00:00:00 2001 From: Silver Date: Tue, 27 Aug 2024 00:08:01 +0200 Subject: [PATCH 03/16] change unet folder to diffusion_models --- __init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/__init__.py b/__init__.py index 21fc397..5a1c799 100644 --- a/__init__.py +++ b/__init__.py @@ -191,7 +191,7 @@ def load_checkpoint(self, ckpt_name): class UNETLoaderNF4: @classmethod def INPUT_TYPES(s): - return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), + return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" From 402b89dc2a7a3d73e121dcd1f98df02afd90f5e1 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Sat, 31 Aug 
2024 14:30:51 +0200 Subject: [PATCH 04/16] Update README.md --- README.md | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c205b41..4988b98 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,27 @@ -#### NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results: https://github.com/city96/ComfyUI-GGUF +#### ~~NOTE: This is very likely Deprecated in favor of GGUF which seems to give better results~~ +Some users may see a speedup by loading the UNET as NF4 with the loader from this repo while loading T5XXL as GGUF with the loader from https://github.com/city96/ComfyUI-GGUF Now on the [manager](https://github.com/ltdrdata/ComfyUI-Manager) for easy installation. Make sure to select Channel:dev in the ComfyUI manager menu or install via git url. -A quickly written custom node that uses code from [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) to support the [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors) and [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors). +You can find the checkpoints and UNET in the linked repositories on Hugging Face or by searching for NF4 on Civitai. + +### [CivitAI search link](https://civitai.com/search/models?baseModel=Flux.1%20S&baseModel=Flux.1%20D&sortBy=models_v9&query=nf4) + +### [nf4 flux unet only](https://huggingface.co/silveroxides/flux1-nf4-unet) + +### [nf4 flux dev checkpoint](https://huggingface.co/lllyasviel/flux1-dev-bnb-nf4/blob/main/flux1-dev-bnb-nf4.safetensors) + +### [nf4 flux schnell checkpoint](https://huggingface.co/silveroxides/flux1-nf4-weights/blob/main/flux1-schnell-bnb-nf4.safetensors) Requires installing bitsandbytes. Make sure your ComfyUI is updated. -The node is: CheckpointLoaderNF4, just plug it in your flux workflow instead of the regular one. 
+The nodes are: +#### "CheckpointLoaderNF4": "Load NF4 Flux Checkpoint" + +#### "UNETLoaderNF4": "Load NF4 Flux UNET" + +Just plug them into your flux workflow instead of the regular ones. + +Code adapted from the implementation by lllyasviel at [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge). From ec3952b16561f68f793d968cb842dc61e931838c Mon Sep 17 00:00:00 2001 From: Thomas Ward Date: Sat, 31 Aug 2024 21:23:29 -0400 Subject: [PATCH 05/16] github workflow: publish to Comfy Registry Requisite: setting your GH Secrets - refer to https://docs.comfy.org/registry/overview#option-2-github-actions - so they are usable in the workflow. Also requires you getting a personal access token on the Comfy Registry (see link). --- .github/workflows/publish_action.yml | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .github/workflows/publish_action.yml diff --git a/.github/workflows/publish_action.yml b/.github/workflows/publish_action.yml new file mode 100644 index 0000000..e3c5c78 --- /dev/null +++ b/.github/workflows/publish_action.yml @@ -0,0 +1,20 @@ +name: Publish to Comfy registry +on: + workflow_dispatch: + push: + branches: + - main + paths: + - "pyproject.toml" + +jobs: + publish-node: + name: Publish Custom Node to registry + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v4 + - name: Publish Custom Node + uses: Comfy-Org/publish-node-action@main + with: + personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. 
From f5a678373283c12ea7e10496c683f95fcdc287bd Mon Sep 17 00:00:00 2001 From: Thomas Ward Date: Sat, 31 Aug 2024 21:28:59 -0400 Subject: [PATCH 06/16] Update publish_action.yml --- .github/workflows/publish_action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish_action.yml b/.github/workflows/publish_action.yml index e3c5c78..2773024 100644 --- a/.github/workflows/publish_action.yml +++ b/.github/workflows/publish_action.yml @@ -3,7 +3,7 @@ on: workflow_dispatch: push: branches: - - main + - master paths: - "pyproject.toml" From 3eada96cfef448535a55c2b4f4b80a4329d0f2a8 Mon Sep 17 00:00:00 2001 From: Silver Date: Tue, 17 Sep 2024 20:38:38 +0200 Subject: [PATCH 07/16] Add pyproject.toml --- pyproject.toml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ee8f495 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "ComfyUI_bnb_NF4_loaders" +description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models" +version = "1.0.0" +license = { file = "LICENSE.txt" } + +[project.urls] +Repository = "https://github.com/silveroxides/ComfyUI_bitsandbytes_NF4" +# Used by Comfy Registry https://comfyregistry.org + +[tool.comfy] +PublisherId = "silveroxides" +DisplayName = "BitsAndBytes NF4 Loaders" +Icon = "" \ No newline at end of file From 00b189891e97955af2f19c61dfd4dcd0dca13628 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Tue, 17 Sep 2024 20:42:07 +0200 Subject: [PATCH 08/16] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ee8f495..7c09b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "ComfyUI_bnb_NF4_loaders" description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes 
NF4 quantized models" -version = "1.0.0" +version = "1.0.1" license = { file = "LICENSE.txt" } [project.urls] @@ -11,4 +11,4 @@ Repository = "https://github.com/silveroxides/ComfyUI_bitsandbytes_NF4" [tool.comfy] PublisherId = "silveroxides" DisplayName = "BitsAndBytes NF4 Loaders" -Icon = "" \ No newline at end of file +Icon = "" From 359f0d47ab9a1731e5ac09a40c024e07e7a5c75d Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Tue, 17 Sep 2024 20:44:57 +0200 Subject: [PATCH 09/16] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7c09b59..b8f04e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "ComfyUI_bnb_NF4_loaders" +name = "comfyui_bnb_nf4_loaders" description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models" -version = "1.0.1" +version = "1.0.2" license = { file = "LICENSE.txt" } [project.urls] From 7a194740d60fe2e52fb31bb6b38918088d1f2278 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Mon, 20 Jan 2025 09:28:30 +0100 Subject: [PATCH 10/16] Add preliminary WIP text encoder loaders --- __init__.py | 95 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 87 insertions(+), 8 deletions(-) diff --git a/__init__.py b/__init__.py index 5a1c799..1ac454c 100644 --- a/__init__.py +++ b/__init__.py @@ -174,6 +174,8 @@ def forward(self, x): class CheckpointLoaderNF4: + NodeId = 'CheckpointLoaderNF4' + NodeName = 'Load FP4 or NF4 Quantized Checkpoint Model' @classmethod def INPUT_TYPES(s): return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), @@ -189,6 +191,8 @@ def load_checkpoint(self, ckpt_name): return out[:3] class UNETLoaderNF4: + NodeId = 'UNETLoaderNF4' + NodeName = 'Load FP4 or NF4 Quantized Diffusion or UNET Model' @classmethod def INPUT_TYPES(s): return {"required": { "unet_name": 
(folder_paths.get_filename_list("diffusion_models"), ), @@ -203,12 +207,87 @@ def load_unet(self, unet_name): model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": OPS}) return (model,) -NODE_CLASS_MAPPINGS = { - "CheckpointLoaderNF4": CheckpointLoaderNF4, - "UNETLoaderNF4": UNETLoaderNF4, -} +class CLIPLoaderNF4: + # WIP + NodeId = 'CLIPLoaderNF4' + NodeName = 'Load FP4 or NF4 Quantized Text Encoder' + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), + },} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + def load_clip(self, clip_name, type): + if type == "stable_cascade": + clip_type = comfy.sd.CLIPType.STABLE_CASCADE + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 + elif type == "stable_audio": + clip_type = comfy.sd.CLIPType.STABLE_AUDIO + elif type == "mochi": + clip_type = comfy.sd.CLIPType.MOCHI + elif type == "ltxv": + clip_type = comfy.sd.CLIPType.LTXV + elif type == "pixart": + clip_type = comfy.sd.CLIPType.PIXART + else: + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + + clip_path = folder_paths.get_full_path("clip", clip_name) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": OPS}) + return (clip,) + +class DualCLIPLoaderNF4: + # WIP + NodeId = 'DualCLIPLoaderNF4' + NodeName = 'Load FP4 or NF4 Quantized Text Encoders' + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "custom_flux", "fluxmod"], ), + },} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = 
"advanced/loaders" -NODE_DISPLAY_NAME_MAPPINGS = { - "CheckpointLoaderNF4": "Load NF4 Flux Checkpoint", - "UNETLoaderNF4": "Load NF4 Flux UNET", -} + def load_clip(self, clip_name1, clip_name2, type): + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) + clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) + if type == "sdxl": + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + elif type == "sd3": + clip_type = comfy.sd.CLIPType.SD3 + elif type == "flux": + clip_type = comfy.sd.CLIPType.FLUX + elif type == "hunyuan_video": + clip_type = comfy.sd.CLIPType.HUNYUAN_VIDEO + elif type == "custom_flux": + clip_type = comfy.sd.CLIPType.FLUXC + elif type == "fluxmod": + clip_type = comfy.sd.CLIPType.FLUXMOD + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options={"custom_operations": OPS}) + return (clip,) + +node_list = [ + # Checkpoint model loaders + CheckpointLoaderNF4, + # Diffusion model loaders + UNETLoaderNF4, + # Text encoder model loaders + # WIP CLIPLoaderNF4, + # WIP DualCLIPLoaderNF4, +] + +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: + NODE_CLASS_MAPPINGS[node.NodeId] = node + NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName From 824826f735d3ce92d5c4db0a0a609a139806d59a Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Mon, 20 Jan 2025 09:30:13 +0100 Subject: [PATCH 11/16] Update pyproject.toml Includes Preliminary inactive WIP Text Encoder loaders and change to the Node CLASS and NAME MAPPING --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b8f04e3..98a1011 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui_bnb_nf4_loaders" description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes 
NF4 quantized models" -version = "1.0.2" +version = "1.0.3" license = { file = "LICENSE.txt" } [project.urls] From 550402372692fb6603ceea31936cf0ffd27b23ea Mon Sep 17 00:00:00 2001 From: snomiao Date: Sat, 25 Jan 2025 07:28:04 +0000 Subject: [PATCH 12/16] chore(publish): update workflow for node publishing with permissions and version check --- .github/workflows/publish_action.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish_action.yml b/.github/workflows/publish_action.yml index 2773024..25ce502 100644 --- a/.github/workflows/publish_action.yml +++ b/.github/workflows/publish_action.yml @@ -7,14 +7,18 @@ on: paths: - "pyproject.toml" +permissions: + issues: write + jobs: publish-node: name: Publish Custom Node to registry runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'silveroxides' }} steps: - name: Check out code uses: actions/checkout@v4 - name: Publish Custom Node - uses: Comfy-Org/publish-node-action@main + uses: Comfy-Org/publish-node-action@v1 with: personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} ## Add your own personal access token to your Github Repository secrets and reference it here. 
From a4759fba005928f19f86b66160d76a333c289b79 Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Fri, 14 Mar 2025 23:59:41 +0100 Subject: [PATCH 13/16] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 98a1011..89df5a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "comfyui_bnb_nf4_loaders" +name = "ComfyUI_bnb_nf4_fp4_Loaders" description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models" -version = "1.0.3" +version = "1.0.4" license = { file = "LICENSE.txt" } [project.urls] From d726db689f90ca0e9971b246a7760bb6533167ad Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Wed, 19 Mar 2025 12:24:11 +0100 Subject: [PATCH 14/16] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 89df5a2..c920ace 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI_bnb_nf4_fp4_Loaders" -description = "Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models" +description = 'Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models' version = "1.0.4" license = { file = "LICENSE.txt" } From a83f91e8af1179a62af2257498c3c651685cb2cd Mon Sep 17 00:00:00 2001 From: Silver <65376327+silveroxides@users.noreply.github.com> Date: Wed, 19 Mar 2025 12:25:07 +0100 Subject: [PATCH 15/16] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c920ace..33e8bf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "ComfyUI_bnb_nf4_fp4_Loaders" description = 'Checkpoint and UNET/Diffusion Model loaders for BitsAndBytes NF4 quantized models' -version = "1.0.4" 
+version = "1.0.5" license = { file = "LICENSE.txt" } [project.urls] From d8681a8d8222dd6a1c8feb0a251edc2a29442709 Mon Sep 17 00:00:00 2001 From: silveroxides Date: Fri, 4 Apr 2025 19:49:14 +0200 Subject: [PATCH 16/16] add support for loading unquantized tensors. Need work though --- __init__.py | 152 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 88 insertions(+), 64 deletions(-) diff --git a/__init__.py b/__init__.py index 1ac454c..f0039e9 100644 --- a/__init__.py +++ b/__init__.py @@ -1,10 +1,7 @@ #shamelessly taken from forge -import nodes import folder_paths -import bitsandbytes - import torch import bitsandbytes as bnb @@ -36,7 +33,7 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta else None ) - quant_state = QuantState( + return QuantState( absmax=state.absmax.to(device), shape=state.shape, code=state.code.to(device), @@ -47,19 +44,39 @@ def copy_quant_state(state: QuantState, device: torch.device = None) -> QuantSta state2=state2, ) - # Manually add chunk_64_norm as an attribute if it exists - if hasattr(state, 'chunk_64_norm'): - quant_state.chunk_64_norm = state.chunk_64_norm.to(device) +class ForgeParams4bit(Params4bit): + _torch_fn_depth=0 - return quant_state + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if cls._torch_fn_depth > 0 or func != torch._C.TensorBase.detach: + return super().__torch_function__(func, types, args, kwargs or {}) + cls._torch_fn_depth += 1 + try: + slf = args[0] + n = cls( + torch.nn.Parameter.detach(slf), + requires_grad=slf.requires_grad, + quant_state=copy_quant_state(slf.quant_state, slf.device), + blocksize=slf.blocksize, + compress_statistics=slf.compress_statistics, + quant_type=slf.quant_type, + quant_storage=slf.quant_storage, + bnb_quantized=slf.bnb_quantized, + module=slf.module + ) + return n + finally: + cls._torch_fn_depth -= 1 -class ForgeParams4bit(Params4bit): - def to(self, *args, **kwargs): + def to(self, *args, 
copy=False, **kwargs): + if copy: + return self.clone().to(*args, **kwargs) device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type == "cuda" and not self.bnb_quantized: return self._quantize(device) else: - n = ForgeParams4bit( + n = self.__class__( torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, quant_state=copy_quant_state(self.quant_state, device), @@ -70,11 +87,6 @@ def to(self, *args, **kwargs): bnb_quantized=self.bnb_quantized, module=self.module ) - - # Manually copy chunk_64_norm if it exists - if hasattr(self.quant_state, 'chunk_64_norm'): - n.quant_state.chunk_64_norm = self.quant_state.chunk_64_norm - self.module.quant_state = n.quant_state self.data = n.data self.quant_state = n.quant_state @@ -136,58 +148,61 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss else: super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -current_device = None -current_dtype = None -current_manual_cast_enabled = False -current_bnb_dtype = None import comfy.ops -class OPS(comfy.ops.manual_cast): - class Linear(ForgeLoader4Bit): - def __init__(self, *args, device=None, dtype=None, **kwargs): - super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) - self.parameters_manual_cast = current_manual_cast_enabled - - def forward(self, x): - self.weight.quant_state = self.quant_state - - if self.bias is not None and self.bias.dtype != x.dtype: - # Maybe this can also be set to all non-bnb ops since the cost is very low. 
- # And it only invokes one time, and most linear does not have bias - self.bias.data = self.bias.data.to(x.dtype) - - if not self.parameters_manual_cast: - return functional_linear_4bits(x, self.weight, self.bias) - elif not self.weight.bnb_quantized: - assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' - layer_original_device = self.weight.device - self.weight = self.weight._quantize(x.device) - bias = self.bias.to(x.device) if self.bias is not None else None - out = functional_linear_4bits(x, self.weight, bias) - self.weight = self.weight.to(layer_original_device) - return out - else: - weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True) - with main_stream_worker(weight, bias, signal): - return functional_linear_4bits(x, weight, bias) - +def make_ops(loader_class, current_device = None, current_dtype = None, current_manual_cast_enabled = False, current_bnb_dtype = None): + + class OPS(comfy.ops.manual_cast): + class Linear(loader_class): + def __init__(self, *args, device=None, dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype) + self.parameters_manual_cast = current_manual_cast_enabled + + def forward(self, x): + self.weight.quant_state = self.quant_state + + if self.bias is not None and self.bias.dtype != x.dtype: + # Maybe this can also be set to all non-bnb ops since the cost is very low. + # And it only invokes one time, and most linear does not have bias + self.bias.data = self.bias.data.to(x.dtype) + + if not self.parameters_manual_cast: + return functional_linear_4bits(x, self.weight, self.bias) + elif not self.weight.bnb_quantized: + assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!' 
+ layer_original_device = self.weight.device + self.weight = self.weight._quantize(x.device) + bias = self.bias.to(x.device) if self.bias is not None else None + out = functional_linear_4bits(x, self.weight, bias) + self.weight = self.weight.to(layer_original_device) + return out + else: + raise RuntimeError("Unexpected state in forward") + + return OPS class CheckpointLoaderNF4: NodeId = 'CheckpointLoaderNF4' NodeName = 'Load FP4 or NF4 Quantized Checkpoint Model' @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }} + return {"required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} + RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "loaders" - def load_checkpoint(self, ckpt_name): + def load_checkpoint(self, ckpt_name, bnb_dtype="default"): + if bnb_dtype == "default": + bnb_dtype = None + ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype) ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": OPS}) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options={"custom_operations": ops}) return out[:3] class UNETLoaderNF4: @@ -195,16 +210,21 @@ class UNETLoaderNF4: NodeName = 'Load FP4 or NF4 Quantized Diffusion or UNET Model' @classmethod def INPUT_TYPES(s): - return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), - }} + return {"required": { + "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} RETURN_TYPES = ("MODEL",) 
FUNCTION = "load_unet" CATEGORY = "advanced/loaders" - def load_unet(self, unet_name): + def load_unet(self, unet_name, bnb_dtype="default"): + if bnb_dtype == "default": + bnb_dtype = None + ops = make_ops(ForgeLoader4Bit, current_bnb_dtype = bnb_dtype) unet_path = folder_paths.get_full_path("unet", unet_name) - model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": OPS}) + model = comfy.sd.load_diffusion_model(unet_path, model_options={"custom_operations": ops}) return (model,) class CLIPLoaderNF4: @@ -213,9 +233,11 @@ class CLIPLoaderNF4: NodeName = 'Load FP4 or NF4 Quantized Text Encoder' @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), - },} + return {"required": { + "clip_name": (folder_paths.get_filename_list("text_encoders"), ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -247,10 +269,12 @@ class DualCLIPLoaderNF4: NodeName = 'Load FP4 or NF4 Quantized Text Encoders' @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), - "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "custom_flux", "fluxmod"], ), - },} + return {"required": { + "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), + "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "custom_flux", "fluxmod"], ), + "bnb_dtype": (("default", "nf4", "fp4"), {"default": "default"}), + }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip"